aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--configure.py10
-rw-r--r--tensorflow/BUILD9
-rw-r--r--tensorflow/__init__.py3
-rw-r--r--tensorflow/c/c_api.cc3
-rw-r--r--tensorflow/c/eager/c_api_test.cc57
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc9
-rw-r--r--tensorflow/cc/gradients/math_grad.cc15
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc8
-rw-r--r--tensorflow/compiler/aot/BUILD2
-rw-r--r--tensorflow/compiler/aot/codegen.cc3
-rw-r--r--tensorflow/compiler/aot/embedded_protocol_buffers.cc6
-rw-r--r--tensorflow/compiler/jit/BUILD6
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc5
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op_test.cc9
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc451
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_internal.h8
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc373
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc115
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc40
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc11
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h5
-rw-r--r--tensorflow/compiler/jit/xla_device.cc5
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc3
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc3
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h3
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py10
-rw-r--r--tensorflow/compiler/tests/eager_test.py1
-rw-r--r--tensorflow/compiler/tests/reverse_ops_test.py25
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py18
-rw-r--r--tensorflow/compiler/tf2xla/BUILD106
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc1380
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.h248
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond_test.cc182
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc1520
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h6
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc69
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc72
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_util.h56
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc668
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.h32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc78
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/identity_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc19
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_op.cc20
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc14
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc35
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h10
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc5
-rw-r--r--tensorflow/compiler/xla/BUILD17
-rw-r--r--tensorflow/compiler/xla/array2d.h4
-rw-r--r--tensorflow/compiler/xla/client/BUILD10
-rw-r--r--tensorflow/compiler/xla/client/client.cc12
-rw-r--r--tensorflow/compiler/xla/client/client_library.cc10
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc6
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc8
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc148
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h57
-rw-r--r--tensorflow/compiler/xla/client/xla_computation.cc4
-rw-r--r--tensorflow/compiler/xla/iterator_util_test.cc6
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc7
-rw-r--r--tensorflow/compiler/xla/literal.cc41
-rw-r--r--tensorflow/compiler/xla/literal.h8
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc57
-rw-r--r--tensorflow/compiler/xla/literal_test.cc13
-rw-r--r--tensorflow/compiler/xla/literal_util.cc20
-rw-r--r--tensorflow/compiler/xla/literal_util.h24
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc4
-rw-r--r--tensorflow/compiler/xla/python/BUILD1
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc13
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i1
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py9
-rw-r--r--tensorflow/compiler/xla/reference_util.cc51
-rw-r--r--tensorflow/compiler/xla/reference_util.h50
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/BUILD109
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc11
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc7
-rw-r--r--tensorflow/compiler/xla/service/backend.cc5
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.cc9
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc20
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc74
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc6
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.cc2
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.cc8
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.cc248
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.h43
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc100
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD8
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc34
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc59
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc5
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc12
-rw-r--r--tensorflow/compiler/xla/service/executable.cc5
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.cc6
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.cc185
-rw-r--r--tensorflow/compiler/xla/service/gather_expander_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD21
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc103
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc50
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc127
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_manager.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc6
-rw-r--r--tensorflow/compiler/xla/service/graphviz_example.cc4
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc12
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc25
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto8
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc103
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc255
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc176
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h65
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc254
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h29
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc66
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc207
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h40
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc112
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_fix.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc107
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc300
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc3
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD8
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc10
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/interpreter/platform.cc5
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc24
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.h3
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc3
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h10
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc4
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.cc3
-rw-r--r--tensorflow/compiler/xla/service/service.cc13
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc225
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h7
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc269
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc2
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc4
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc5
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h20
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc3
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.cc11
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc15
-rw-r--r--tensorflow/compiler/xla/service/while_util.cc8
-rw-r--r--tensorflow/compiler/xla/service/while_util_test.cc3
-rw-r--r--tensorflow/compiler/xla/shape_tree.h2
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD29
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc6
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h4
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc191
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc312
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc41
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h16
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc3
-rw-r--r--tensorflow/compiler/xla/tests/llvm_compiler_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc8
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc2
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc201
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h21
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc47
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc5
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.cc4
-rw-r--r--tensorflow/compiler/xla/util.h116
-rw-r--r--tensorflow/compiler/xla/xla.proto9
-rw-r--r--tensorflow/compiler/xla/xla_data.proto18
-rw-r--r--tensorflow/contrib/BUILD2
-rw-r--r--tensorflow/contrib/__init__.py3
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py2
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py61
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py15
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py4
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD28
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py166
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py101
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops.py5
-rw-r--r--tensorflow/contrib/cmake/external/nsync.cmake8
-rw-r--r--tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt325
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt2
-rw-r--r--tensorflow/contrib/crf/__init__.py2
-rw-r--r--tensorflow/contrib/data/__init__.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py28
-rw-r--r--tensorflow/contrib/distribute/BUILD1
-rw-r--r--tensorflow/contrib/distribute/__init__.py2
-rw-r--r--tensorflow/contrib/distribute/python/BUILD26
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py19
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py11
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py10
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py19
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py148
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py32
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy.py141
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy_test.py62
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py16
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py105
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py73
-rw-r--r--tensorflow/contrib/distribute/python/step_fn.py4
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py22
-rw-r--r--tensorflow/contrib/distribute/python/values.py180
-rw-r--r--tensorflow/contrib/eager/python/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/densenet_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py8
-rw-r--r--tensorflow/contrib/eager/python/saver_test.py51
-rw-r--r--tensorflow/contrib/eager/python/tfe.py4
-rw-r--r--tensorflow/contrib/estimator/__init__.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py26
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks_test.py32
-rw-r--r--tensorflow/contrib/ffmpeg/__init__.py2
-rw-r--r--tensorflow/contrib/framework/__init__.py4
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_ops.py5
-rw-r--r--tensorflow/contrib/framework/python/ops/script_ops.py2
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h6
-rw-r--r--tensorflow/contrib/graph_editor/__init__.py4
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py76
-rw-r--r--tensorflow/contrib/image/python/ops/interpolate_spline.py35
-rw-r--r--tensorflow/contrib/integrate/__init__.py4
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/layers/__init__.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py9
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py51
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization.py25
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py100
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py16
-rw-r--r--tensorflow/contrib/learn/__init__.py3
-rw-r--r--tensorflow/contrib/linalg/__init__.py3
-rw-r--r--tensorflow/contrib/lite/BUILD8
-rw-r--r--tensorflow/contrib/lite/build_def.bzl1
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/context.h17
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD21
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.cc37
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h22
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc12
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc6
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h4
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc10
-rw-r--r--tensorflow/contrib/lite/examples/android/build.gradle1
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h6
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml1
-rw-r--r--tensorflow/contrib/lite/interpreter.h32
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc58
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java44
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java70
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java29
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java51
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java15
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java19
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java22
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h138
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h300
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h212
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc47
-rw-r--r--tensorflow/contrib/lite/kernels/pack_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc5
-rw-r--r--tensorflow/contrib/lite/model.cc12
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/python/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/convert.py35
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py93
-rw-r--r--tensorflow/contrib/lite/python/lite.py11
-rw-r--r--tensorflow/contrib/lite/python/op_hint.py898
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py6
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs7
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h156
-rw-r--r--tensorflow/contrib/lite/string.h6
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py65
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc6
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc1
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h6
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc10
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h3
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD41
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc13
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h8
-rw-r--r--tensorflow/contrib/lite/util.cc7
-rw-r--r--tensorflow/contrib/lite/util.h10
-rw-r--r--tensorflow/contrib/lite/util_test.cc10
-rw-r--r--tensorflow/contrib/lookup/BUILD1
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py55
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py168
-rw-r--r--tensorflow/contrib/losses/__init__.py2
-rw-r--r--tensorflow/contrib/losses/python/losses/__init__.py2
-rw-r--r--tensorflow/contrib/losses/python/metric_learning/__init__.py4
-rwxr-xr-xtensorflow/contrib/makefile/compile_nsync.sh1
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/contrib/metrics/__init__.py4
-rw-r--r--tensorflow/contrib/model_pruning/BUILD1
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py4
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py70
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py62
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.h7
-rw-r--r--tensorflow/contrib/opt/BUILD16
-rw-r--r--tensorflow/contrib/opt/__init__.py2
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py166
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py113
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer.py164
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer_test.py127
-rw-r--r--tensorflow/contrib/quantize/BUILD2
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph.py53
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py15
-rw-r--r--tensorflow/contrib/rnn/__init__.py2
-rw-r--r--tensorflow/contrib/seq2seq/__init__.py4
-rw-r--r--tensorflow/contrib/signal/__init__.py4
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD1
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py20
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc19
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h4
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc34
-rw-r--r--tensorflow/contrib/tensorrt/BUILD18
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc261
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc9
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc78
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py154
-rw-r--r--tensorflow/contrib/tensorrt/test/batch_matmul_test.py42
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py58
-rw-r--r--tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py47
-rw-r--r--tensorflow/contrib/tensorrt/test/concatenation_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/const_broadcast_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/manual_test.py114
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/rank_two_test.py89
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py304
-rw-r--r--tensorflow/contrib/tensorrt/test/unary_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py21
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD2
-rw-r--r--tensorflow/contrib/timeseries/examples/predict.py16
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py3
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py95
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py69
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py53
-rw-r--r--tensorflow/contrib/training/__init__.py4
-rw-r--r--tensorflow/contrib/util/__init__.py2
-rw-r--r--tensorflow/core/BUILD19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Fill.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc8
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h3
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc7
-rw-r--r--tensorflow/core/common_runtime/eager/context.h4
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc17
-rw-r--r--tensorflow/core/common_runtime/executor.cc123
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc86
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h82
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_allocator.h6
-rw-r--r--tensorflow/core/distributed_runtime/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/master.cc51
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc14
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc65
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h14
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache.h2
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_wrapper.h4
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc5
-rw-r--r--tensorflow/core/framework/op_def_util.cc9
-rw-r--r--tensorflow/core/framework/op_def_util.h5
-rw-r--r--tensorflow/core/framework/register_types.h5
-rw-r--r--tensorflow/core/framework/resource_mgr.h4
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc39
-rw-r--r--tensorflow/core/kernels/BUILD70
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h5
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD63
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h132
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc99
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h330
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc276
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h344
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc223
-rw-r--r--tensorflow/core/kernels/cast_op.cc8
-rw-r--r--tensorflow/core/kernels/colorspace_op.h6
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.h5
-rw-r--r--tensorflow/core/kernels/constant_op.cc5
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc2
-rw-r--r--tensorflow/core/kernels/cross_op.h6
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h5
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc3
-rw-r--r--tensorflow/core/kernels/cwise_ops.h8
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_common.cu.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h6
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc7
-rw-r--r--tensorflow/core/kernels/gemm_functors.h5
-rw-r--r--tensorflow/core/kernels/hexagon/soc_interface.h6
-rw-r--r--tensorflow/core/kernels/list_kernels.h7
-rw-r--r--tensorflow/core/kernels/lookup_table_op.h9
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_diag_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_solve_ls_op_impl.h5
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc174
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc144
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc157
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h414
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.cc34
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d_gpu.h6
-rw-r--r--tensorflow/core/kernels/qr_op_impl.h7
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h7
-rw-r--r--tensorflow/core/kernels/regex_replace_op.cc80
-rw-r--r--tensorflow/core/kernels/regex_replace_op_test.cc137
-rw-r--r--tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h5
-rw-r--r--tensorflow/core/kernels/softplus_op.cc11
-rw-r--r--tensorflow/core/kernels/softsign_op.cc11
-rw-r--r--tensorflow/core/kernels/sparse_xent_op.h6
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc38
-rw-r--r--tensorflow/core/kernels/string_length_op.cc45
-rw-r--r--tensorflow/core/kernels/string_split_op.cc111
-rw-r--r--tensorflow/core/kernels/string_split_op_test.cc129
-rw-r--r--tensorflow/core/kernels/svd_op_impl.h5
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc12
-rw-r--r--tensorflow/core/kernels/warn_about_ints.cc33
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h5
-rw-r--r--tensorflow/core/kernels/xent_op.h6
-rw-r--r--tensorflow/core/lib/core/stringpiece.h16
-rw-r--r--tensorflow/core/lib/core/stringpiece_test.cc4
-rw-r--r--tensorflow/core/ops/array_ops_test.cc18
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt97
-rw-r--r--tensorflow/core/ops/math_grad.cc8
-rw-r--r--tensorflow/core/ops/math_grad_test.cc6
-rw-r--r--tensorflow/core/ops/math_ops.cc7
-rw-r--r--tensorflow/core/ops/math_ops_test.cc2
-rw-r--r--tensorflow/core/ops/nn_ops.cc87
-rw-r--r--tensorflow/core/ops/ops.pbtxt97
-rw-r--r--tensorflow/core/ops/string_ops.cc13
-rw-r--r--tensorflow/core/platform/default/build_config.bzl1187
-rw-r--r--tensorflow/core/platform/default/protobuf.h2
-rw-r--r--tensorflow/core/platform/default/protobuf_compiler.h (renamed from tensorflow/compiler/xla/ptr_util.h)26
-rw-r--r--tensorflow/core/platform/protobuf_compiler.h (renamed from tensorflow/core/kernels/warn_about_ints.h)20
-rw-r--r--tensorflow/core/util/env_var.h5
-rw-r--r--tensorflow/core/util/mkl_util.h103
-rw-r--r--tensorflow/core/util/strided_slice_op.cc2
-rw-r--r--tensorflow/core/util/tensor_format.cc4
-rw-r--r--tensorflow/core/util/tensor_format.h1
-rw-r--r--tensorflow/docs_src/about/index.md6
-rw-r--r--tensorflow/docs_src/api_guides/python/client.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/constant_op.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/input_dataset.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/io_ops.md10
-rw-r--r--tensorflow/docs_src/api_guides/python/meta_graph.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/reading_data.md24
-rw-r--r--tensorflow/docs_src/api_guides/python/regression_examples.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/summary.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/threading_and_queues.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/train.md8
-rw-r--r--tensorflow/docs_src/community/contributing.md6
-rw-r--r--tensorflow/docs_src/community/index.md6
-rw-r--r--tensorflow/docs_src/community/style_guide.md2
-rw-r--r--tensorflow/docs_src/deploy/distributed.md2
-rw-r--r--tensorflow/docs_src/deploy/hadoop.md4
-rw-r--r--tensorflow/docs_src/deploy/index.md6
-rw-r--r--tensorflow/docs_src/deploy/s3.md2
-rw-r--r--tensorflow/docs_src/extend/add_filesys.md2
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md10
-rw-r--r--tensorflow/docs_src/extend/architecture.md8
-rw-r--r--tensorflow/docs_src/extend/index.md12
-rw-r--r--tensorflow/docs_src/extend/language_bindings.md2
-rw-r--r--tensorflow/docs_src/extend/new_data_formats.md10
-rw-r--r--tensorflow/docs_src/guide/checkpoints.md8
-rw-r--r--tensorflow/docs_src/guide/custom_estimators.md14
-rw-r--r--tensorflow/docs_src/guide/datasets.md16
-rw-r--r--tensorflow/docs_src/guide/datasets_for_estimators.md14
-rw-r--r--tensorflow/docs_src/guide/debugger.md2
-rw-r--r--tensorflow/docs_src/guide/eager.md2
-rw-r--r--tensorflow/docs_src/guide/embedding.md2
-rw-r--r--tensorflow/docs_src/guide/estimators.md4
-rw-r--r--tensorflow/docs_src/guide/faq.md38
-rw-r--r--tensorflow/docs_src/guide/feature_columns.md6
-rw-r--r--tensorflow/docs_src/guide/graph_viz.md4
-rw-r--r--tensorflow/docs_src/guide/graphs.md8
-rw-r--r--tensorflow/docs_src/guide/index.md46
-rw-r--r--tensorflow/docs_src/guide/low_level_intro.md18
-rw-r--r--tensorflow/docs_src/guide/premade_estimators.md18
-rw-r--r--tensorflow/docs_src/guide/saved_model.md10
-rw-r--r--tensorflow/docs_src/guide/summaries_and_tensorboard.md8
-rw-r--r--tensorflow/docs_src/guide/tensors.md2
-rw-r--r--tensorflow/docs_src/guide/using_gpu.md2
-rw-r--r--tensorflow/docs_src/guide/using_tpu.md16
-rw-r--r--tensorflow/docs_src/guide/version_compat.md9
-rw-r--r--tensorflow/docs_src/install/index.md18
-rw-r--r--tensorflow/docs_src/install/install_c.md4
-rw-r--r--tensorflow/docs_src/install/install_go.md4
-rw-r--r--tensorflow/docs_src/install/install_java.md6
-rw-r--r--tensorflow/docs_src/install/install_linux.md2
-rw-r--r--tensorflow/docs_src/install/install_sources.md4
-rw-r--r--tensorflow/docs_src/performance/index.md22
-rw-r--r--tensorflow/docs_src/performance/performance_guide.md16
-rw-r--r--tensorflow/docs_src/performance/performance_models.md2
-rw-r--r--tensorflow/docs_src/performance/quantization.md2
-rw-r--r--tensorflow/docs_src/performance/xla/index.md10
-rw-r--r--tensorflow/docs_src/performance/xla/jit.md2
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md318
-rw-r--r--tensorflow/docs_src/performance/xla/tfcompile.md4
-rw-r--r--tensorflow/docs_src/tutorials/_toc.yaml50
-rw-r--r--tensorflow/docs_src/tutorials/eager/index.md1
-rw-r--r--tensorflow/docs_src/tutorials/estimators/cnn.md16
-rw-r--r--tensorflow/docs_src/tutorials/images/deep_cnn.md20
-rw-r--r--tensorflow/docs_src/tutorials/images/image_recognition.md4
-rw-r--r--tensorflow/docs_src/tutorials/representation/kernel_methods.md4
-rw-r--r--tensorflow/docs_src/tutorials/representation/linear.md4
-rw-r--r--tensorflow/docs_src/tutorials/representation/word2vec.md4
-rw-r--r--tensorflow/docs_src/tutorials/sequences/recurrent.md2
-rw-r--r--tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md11
-rw-r--r--tensorflow/examples/ios/benchmark/ios_image_load.h6
-rw-r--r--tensorflow/examples/ios/camera/ios_image_load.h6
-rw-r--r--tensorflow/go/op/wrappers.go1380
-rw-r--r--tensorflow/js/BUILD52
-rw-r--r--tensorflow/js/ops/ts_op_gen.cc199
-rw-r--r--tensorflow/js/ops/ts_op_gen.h (renamed from tensorflow/contrib/lite/delegates/eager/constants.h)24
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc212
-rw-r--r--tensorflow/python/BUILD22
-rw-r--r--tensorflow/python/client/client_lib.py2
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/compat/compat.py12
-rw-r--r--tensorflow/python/data/__init__.py2
-rw-r--r--tensorflow/python/data/ops/BUILD1
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py4
-rw-r--r--tensorflow/python/data/util/BUILD35
-rw-r--r--tensorflow/python/data/util/structure.py315
-rw-r--r--tensorflow/python/data/util/structure_test.py327
-rw-r--r--tensorflow/python/debug/__init__.py2
-rw-r--r--tensorflow/python/distribute/BUILD40
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py204
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_context.py31
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py248
-rw-r--r--tensorflow/python/distribute/multi_worker_util.py80
-rw-r--r--tensorflow/python/distribute/multi_worker_util_test.py107
-rw-r--r--tensorflow/python/eager/BUILD36
-rw-r--r--tensorflow/python/eager/backprop.py78
-rw-r--r--tensorflow/python/eager/benchmarks_test.py30
-rw-r--r--tensorflow/python/eager/function.py379
-rw-r--r--tensorflow/python/eager/graph_callable.py435
-rw-r--r--tensorflow/python/eager/graph_callable_test.py249
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc3
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py248
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py20
-rw-r--r--tensorflow/python/estimator/canned/prediction_keys.py1
-rw-r--r--tensorflow/python/estimator/estimator.py231
-rw-r--r--tensorflow/python/estimator/estimator_test.py13
-rw-r--r--tensorflow/python/estimator/exporter_test.py37
-rw-r--r--tensorflow/python/estimator/gc.py8
-rw-r--r--tensorflow/python/estimator/gc_test.py11
-rw-r--r--tensorflow/python/estimator/keras.py311
-rw-r--r--tensorflow/python/estimator/keras_test.py10
-rw-r--r--tensorflow/python/estimator/model_fn.py2
-rw-r--r--tensorflow/python/estimator/training.py15
-rw-r--r--tensorflow/python/estimator/training_test.py33
-rw-r--r--tensorflow/python/framework/constant_op.py13
-rw-r--r--tensorflow/python/framework/ops.py32
-rw-r--r--tensorflow/python/framework/python_op_gen.cc9
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.cc9
-rw-r--r--tensorflow/python/framework/tensor_shape.py3
-rw-r--r--tensorflow/python/framework/test_util.py328
-rw-r--r--tensorflow/python/framework/test_util_test.py45
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/keras/backend.py11
-rw-r--r--tensorflow/python/keras/backend_test.py24
-rw-r--r--tensorflow/python/keras/callbacks_test.py107
-rw-r--r--tensorflow/python/keras/engine/network.py12
-rw-r--r--tensorflow/python/keras/engine/sequential.py4
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py39
-rw-r--r--tensorflow/python/keras/engine/training.py26
-rw-r--r--tensorflow/python/keras/engine/training_test.py1391
-rw-r--r--tensorflow/python/keras/models.py218
-rw-r--r--tensorflow/python/keras/models_test.py134
-rw-r--r--tensorflow/python/keras/testing_utils.py19
-rw-r--r--tensorflow/python/kernel_tests/BUILD33
-rw-r--r--tensorflow/python/kernel_tests/batch_scatter_ops_test.py129
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py25
-rw-r--r--tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py35
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py76
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py7
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/string_length_op_test.py37
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/template_test.py18
-rw-r--r--tensorflow/python/layers/base.py4
-rw-r--r--tensorflow/python/lib/core/py_func.cc2
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc13
-rw-r--r--tensorflow/python/lib/io/py_record_writer.h2
-rw-r--r--tensorflow/python/lib/io/python_io.py2
-rw-r--r--tensorflow/python/lib/io/tf_record.py4
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py6
-rw-r--r--tensorflow/python/ops/array_ops.py7
-rw-r--r--tensorflow/python/ops/check_ops.py3
-rw-r--r--tensorflow/python/ops/control_flow_ops.py3
-rw-r--r--tensorflow/python/ops/functional_ops.py3
-rw-r--r--tensorflow/python/ops/histogram_ops.py2
-rw-r--r--tensorflow/python/ops/image_ops.py2
-rw-r--r--tensorflow/python/ops/io_ops.py3
-rw-r--r--tensorflow/python/ops/math_grad.py10
-rw-r--r--tensorflow/python/ops/math_grad_test.py15
-rw-r--r--tensorflow/python/ops/math_ops.py38
-rw-r--r--tensorflow/python/ops/math_ops_test.py17
-rw-r--r--tensorflow/python/ops/metrics_impl.py202
-rw-r--r--tensorflow/python/ops/nn.py2
-rw-r--r--tensorflow/python/ops/nn_grad.py4
-rw-r--r--tensorflow/python/ops/nn_impl.py4
-rw-r--r--tensorflow/python/ops/nn_ops.py12
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py25
-rw-r--r--tensorflow/python/ops/script_ops.py5
-rw-r--r--tensorflow/python/ops/session_ops.py6
-rw-r--r--tensorflow/python/ops/sparse_ops.py15
-rw-r--r--tensorflow/python/ops/state_ops.py105
-rw-r--r--tensorflow/python/ops/string_ops.py39
-rw-r--r--tensorflow/python/ops/variable_scope.py26
-rw-r--r--tensorflow/python/ops/variables.py2
-rw-r--r--tensorflow/python/platform/test.py2
-rw-r--r--tensorflow/python/summary/summary.py2
-rw-r--r--tensorflow/python/tools/BUILD6
-rw-r--r--tensorflow/python/tools/component_api_helper.py85
-rw-r--r--tensorflow/python/tools/freeze_graph.py4
-rw-r--r--tensorflow/python/training/adagrad.py26
-rw-r--r--tensorflow/python/training/adagrad_test.py33
-rw-r--r--tensorflow/python/training/checkpointable/BUILD13
-rw-r--r--tensorflow/python/training/checkpointable/base.py128
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py13
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py6
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py9
-rw-r--r--tensorflow/python/training/checkpointable/util.py59
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py15
-rw-r--r--tensorflow/python/training/distribute.py10
-rw-r--r--tensorflow/python/training/input.py3
-rw-r--r--tensorflow/python/training/monitored_session.py116
-rw-r--r--tensorflow/python/training/monitored_session_test.py118
-rw-r--r--tensorflow/python/training/saver.py18
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer.py6
-rw-r--r--tensorflow/python/training/training.py2
-rw-r--r--tensorflow/python/util/nest.py76
-rw-r--r--tensorflow/python/util/util.cc1
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i52
-rw-r--r--tensorflow/tensorflow.bzl6
-rw-r--r--tensorflow/tools/api/golden/BUILD2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt23
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt87
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt23
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt29
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt18
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt23
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt63
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt15
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt33
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py4
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh12
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh12
-rw-r--r--tensorflow/tools/common/public_api.py5
-rwxr-xr-xtensorflow/tools/docker/parameterized_docker_build.sh2
-rw-r--r--tensorflow/tools/docs/generate.py5
-rw-r--r--tensorflow/tools/docs/generate_lib.py77
-rw-r--r--tensorflow/tools/docs/parser.py6
-rw-r--r--tensorflow/tools/pip_package/BUILD15
-rw-r--r--tensorflow/tools/pip_package/MANIFEST.in1
-rw-r--r--tensorflow/tools/proto_text/BUILD1
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions.cc1
-rw-r--r--tensorflow/workspace.bzl47
-rw-r--r--third_party/gpus/cuda/BUILD.windows.tpl1
-rw-r--r--third_party/kafka/BUILD6
-rw-r--r--third_party/ngraph/build_defs.bzl13
-rw-r--r--third_party/ngraph/ngraph.BUILD34
-rw-r--r--third_party/ngraph/ngraph_tf.BUILD8
-rw-r--r--third_party/ngraph/nlohmann_json.BUILD2
-rw-r--r--third_party/repo.bzl229
-rw-r--r--third_party/systemlibs/nsync.BUILD23
-rw-r--r--third_party/systemlibs/syslibs_configure.bzl174
843 files changed, 24492 insertions, 12340 deletions
diff --git a/README.md b/README.md
index 669ff5b711..16d354ca7b 100644
--- a/README.md
+++ b/README.md
@@ -100,7 +100,7 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
-| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
+| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
## For more information
diff --git a/configure.py b/configure.py
index 7acc6932eb..10fee6993e 100644
--- a/configure.py
+++ b/configure.py
@@ -848,7 +848,7 @@ def set_tf_cuda_version(environ_cp):
cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths]
if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
- break
+ break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
@@ -1399,6 +1399,13 @@ def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')
+def set_system_libs_flag(environ_cp):
+ syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
+ syslibs = ','.join(sorted(syslibs.split(',')))
+ if syslibs and syslibs != '':
+ write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
+
+
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
# The non-monolithic build is not supported yet
@@ -1557,6 +1564,7 @@ def main():
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
+ set_system_libs_flag(environ_cp)
if is_windows():
set_windows_build_flags(environ_cp)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index e5654a5141..9cc4c4567b 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -23,7 +23,10 @@ load(
"//tensorflow/python/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
-load("//third_party/ngraph:build_defs.bzl", "if_ngraph")
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
# Config setting used when building for products
# which requires restricted licenses to be avoided.
@@ -440,7 +443,7 @@ filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(
[
- "//third_party/intel_mkl_ml",
+ "//third_party/mkl:intel_binary_blob",
],
),
)
@@ -572,7 +575,7 @@ tf_cc_shared_object(
"//tensorflow/cc:scope",
"//tensorflow/cc/profiler",
"//tensorflow/core:tensorflow",
- ] + if_ngraph(["@ngraph_tf//:ngraph_tf"])
+ ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]),
)
exports_files(
diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py
index 440e9f8dbd..21677512b6 100644
--- a/tensorflow/__init__.py
+++ b/tensorflow/__init__.py
@@ -28,7 +28,8 @@ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-app.flags = flags # pylint: disable=undefined-variable
+from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
+app.flags = flags
del absolute_import
del division
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 19ccb6e71d..b8adf6c127 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -202,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf->len_ = len;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
- reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
+ reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
+ 0) {
// TF_STRING and TF_RESOURCE tensors have a different representation in
// TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
// (any alignment requirements will be taken care of by TF_TensorToTensor
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 71d5f3613c..7126227cf5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) {
}
BENCHMARK(BM_ReadVariable);
+TEST(CAPI, StringAttributes) {
+ // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+ // returns.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ std::vector<int64_t> dims(4, 1);
+ TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* tensor =
+ TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+ float tensor_data[] = {1};
+ memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, tensor_handle, status);
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(tensor_handle);
+
+ std::vector<int64_t> values(4, 1);
+ TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
+ TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
+
+ const int BUFFER_SIZE = 10;
+ char buffer[BUFFER_SIZE];
+ std::strncpy(buffer, "VALID", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
+ // Overwriting value in "buffer", should be fine since TFE_Op
+ // shouldn't be holding on to it.
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
+
+ TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(op, &retvals[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ tensor = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ EXPECT_EQ(4, TF_TensorByteSize(tensor));
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(op);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
} // namespace
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index dfdef88945..c20ea95a15 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -508,15 +508,6 @@ bool HasOptionalAttrs(
return false;
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 5dcf00857d..1329b568ab 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
- auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
- auto gx_2 =
- Mul(scope, grad_inputs[0],
- UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 88aef1fab4..c16938322c 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,6 +33,7 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
+using ops::DivNoNan;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@@ -48,7 +49,6 @@ using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
-using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@@ -854,13 +854,13 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
-TEST_F(NaryGradTest, UnsafeDiv) {
+TEST_F(NaryGradTest, DivNoNan) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
- const auto y = UnsafeDiv(
+ const auto y = DivNoNan(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
@@ -868,7 +868,7 @@ TEST_F(NaryGradTest, UnsafeDiv) {
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
- const auto y = UnsafeDiv(scope_, x, zero);
+ const auto y = DivNoNan(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 1899a32e4d..2220d0786d 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -55,6 +55,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
],
)
@@ -193,6 +194,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 89fefdad54..a8485576ac 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
@@ -617,7 +618,7 @@ Status GenerateMetadata(const CodegenOpts& opts,
if (opts.gen_program_shape) {
program_shape =
- tensorflow::MakeUnique<xla::ProgramShape>(compile_result.program_shape);
+ absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
// space.
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 4e27aafec7..8fb2fad31c 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -105,7 +105,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
error.c_str());
}
- return WrapUnique(target->createTargetMachine(
+ return absl::WrapUnique(target->createTargetMachine(
normalized_triple, /*CPU=*/"",
/*Features=*/"", llvm::TargetOptions(), llvm::None));
}
@@ -118,7 +118,7 @@ StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
llvm::LLVMContext llvm_context;
std::unique_ptr<llvm::Module> module_with_serialized_proto =
- MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
+ absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
EmbeddedProtocolBuffers result;
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index e059f77563..2c9adfe4f0 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -128,11 +128,11 @@ cc_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
- "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -191,6 +191,7 @@ cc_library(
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
+ "@com_google_absl//absl/memory",
],
)
@@ -235,6 +236,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
)
@@ -283,6 +285,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -303,6 +306,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index a2e6285339..1b1ce78ed2 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -223,8 +224,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- *kernel = MakeUnique<XlaLocalLaunchBase>(&construction, constant_arg_indices,
- resource_arg_indices, function);
+ *kernel = absl::make_unique<XlaLocalLaunchBase>(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
index b75ab486b8..7386660762 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
for (const auto& fdef : flib) {
*(proto.add_function()) = fdef;
}
- lib_def_ =
- MakeUnique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
+ lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+ OpRegistry::Global(), proto);
OptimizerOptions opts;
- device_mgr_ = MakeUnique<DeviceMgr>(devices_);
- pflr_ = MakeUnique<ProcessFunctionLibraryRuntime>(
+ device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 62007e6115..0ca0f949dc 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -21,18 +21,79 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
+// ==================
//
// We map every output produced by each node in the TensorFlow graph (including
// control dependence) into an instance of the Predicate class. Instances of
// Predicate denote logical formulas and mapping a node `n` to a predicate
-// `pred` implies that `n` is executed whenver `pred` is true. Then we can
-// deduce mismatching liveness in the inputs to node by comparing the predicate
-// those inputs are mapped to.
+// `pred` implies that `n` is live whenever `pred` is true. Then we can deduce
+// mismatching liveness in the inputs to node by comparing the predicate those
+// inputs are mapped to. The core logic of this pass resides in creating the
+// map from TensorFlow nodes to predicates.
//
-// Loops are handled pessimistically -- we map Merge nodes with backedges to
-// uninterpreted symbols (the same kind we use to represent Switch and _Recv).
-// Predicate equality has to hold over all possible assignments to these
-// uninterpreted symbols.
+//
+// MAPPING NODES TO PREDICATES, MODULO CYCLES
+// ------------------------------------------
+//
+// If we ignore cycles for a moment, computing predicates is fairly
+// straightforward. We traverse the graph in RPO, mapping each node to a
+// predicate based on the predicates its inputs are mapped to. For instance a
+// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
+// Roughtly speaking, we abstract interpret each node on the "liveness" domain,
+// where values in the domain represent if a tensor carries a dead signal or
+// not.
+//
+//
+// DEALING WITH CYCLES
+// -------------------
+//
+// We map Merge nodes that are the target of a backedge to AndRecurrence
+// instances. An AndRecurrence with start() = S and step() = X, printed as
+// {S,&,X}, *roughly* represents the infinite list of predicates
+// [S,S&X,S&X&X,S&X&X, ...]. So {S,&,X} can be used to represent the predicate
+// for Merge in a graph like:
+//
+// Init
+// |
+// v
+// Merge <-----------+
+// | |
+// v |
+// Incr |
+// | |
+// v |
+// Switch <- Cond |
+// | |
+// v (oidx: 1) |
+// | |
+// +---------------+
+//
+// Where S is the predicate for Init and X is the predicate that asserts that
+// Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff
+// S is true, live on the second iteration iff "S&X" is true, live on the third
+// iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would
+// normally be equivalent to S&X which isn't quite what we want to represent.
+// Instead we want {S,&,X} to denote the infinite list [S, S&X,
+// S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
+// true on iteration 0, 1, 2 respectively. This is made more precise in the
+// comment on the AndRecurrence class.
+//
+// The general algorithm that deals with cycles does two RPO (reverse post
+// order) passes over the graph. On the first pass it assigns a symbolic
+// predicate to merge nodes with backedges. On the second pass it tries to
+// pattern matche the predicates for the backedges of these merges and infer an
+// AndRecurrence for the merge.
+//
+// In other words, we do a pessimistic data flow analysis where the data-flow
+// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
+// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
+// converge. We don't do an optimistic data flow analysis to make pattern
+// matching easier: if we assigned the predicate of the initial value to the
+// merge during the first pass, on the second pass the backedge may see a
+// simplified value that would be difficult to pattern match.
+//
+// We still use symbolic predicates for merges for which we can't pattern match
+// on the backedge predicate. This is conservatively correct.
namespace tensorflow {
@@ -42,7 +103,7 @@ namespace {
// above.
class Predicate {
public:
- enum class Kind { kAnd, kOr, kNot, kSymbol };
+ enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol };
virtual string ToString() const = 0;
int64 hash() const { return hash_; }
@@ -51,6 +112,12 @@ class Predicate {
virtual Kind kind() const = 0;
virtual ~Predicate() {}
+ // Invokes func on p and on all of its operands recursively. Does not invoke
+ // `func` on the same Predicate instance twice. Aborts the search if `func`
+ // returns true.
+ template <typename FunctionTy>
+ static void Visit(Predicate* p, const FunctionTy& func);
+
protected:
explicit Predicate(int64 hash) : hash_(hash) {}
@@ -145,10 +212,44 @@ class NotPredicate : public Predicate {
std::array<Predicate*, 1> operands_;
};
+// Represents an infinite list of predicates.
+//
+// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands
+// for the list of predicates:
+//
+// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ...
+//
+// where GenSym(<expression>, <id>) renames every SymbolPredicate in
+// <expression> by appending <id> to it, in effect creating a "fresh" symbol.
+// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on
+// subsequent iterations".
+class AndRecurrencePredicate : public Predicate {
+ public:
+ explicit AndRecurrencePredicate(Predicate* start, Predicate* step)
+ : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})),
+ operands_({start, step}) {}
+
+ Predicate* start() const { return operands_[0]; }
+ Predicate* step() const { return operands_[1]; }
+
+ string ToString() const override {
+ return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
+ "}");
+ }
+
+ Kind kind() const override { return Kind::kAndRecurrence; }
+
+ gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+
+ private:
+ std::array<Predicate*, 2> operands_;
+};
+
// Represents an uninterpreted symbol in a logical predicate.
//
// Two predicates are equivalent iff they are equivalent for all assignments to
-// the symbols contained in them.
+// the symbols contained in them, i.e. predicates are forall qualified over
+// symbols.
class SymbolPredicate : public Predicate {
public:
explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
@@ -184,6 +285,29 @@ class SymbolPredicate : public Predicate {
}
};
+template <typename FunctionTy>
+/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
+ gtl::FlatSet<Predicate*> visited;
+ std::vector<Predicate*> stack;
+
+ stack.push_back(p);
+ visited.insert(p);
+
+ while (!stack.empty()) {
+ Predicate* current = stack.back();
+ stack.pop_back();
+ bool done = func(current);
+ if (done) {
+ return;
+ }
+ for (Predicate* op : current->GetOperands()) {
+ if (visited.insert(op).second) {
+ stack.push_back(op);
+ }
+ }
+ }
+}
+
// Creates and owns Predicate instances. Simplifies predicates as it creates
// them.
class PredicateFactory {
@@ -209,6 +333,21 @@ class PredicateFactory {
}
}
+ Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) {
+ auto it = interned_and_rec_instances_.find({start, step});
+ if (it != interned_and_rec_instances_.end()) {
+ return it->second.get();
+ }
+
+ std::unique_ptr<Predicate> new_pred =
+ Make<AndRecurrencePredicate>(start, step);
+ Predicate* new_pred_ptr = new_pred.get();
+ CHECK(interned_and_rec_instances_
+ .emplace(SignatureForAndRec(start, step), std::move(new_pred))
+ .second);
+ return new_pred_ptr;
+ }
+
Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
SignatureForSymbol signature = {tensor_id, must_be_true};
auto it = interned_symbol_instances_.find(signature);
@@ -249,6 +388,7 @@ class PredicateFactory {
using SignatureForAndOr =
std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
using SignatureForNot = Predicate*;
+ using SignatureForAndRec = std::pair<Predicate*, Predicate*>;
using SignatureForSymbol = std::pair<SafeTensorId, bool>;
struct HashSignatureForAndOr {
@@ -273,6 +413,8 @@ class PredicateFactory {
interned_and_or_instances_;
gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
interned_not_instances_;
+ gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>>
+ interned_and_rec_instances_;
gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
HashSignatureForSymbol>
interned_symbol_instances_;
@@ -353,6 +495,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
+ Status PopulateWithReversePostOrder(gtl::ArraySlice<Node*> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
@@ -361,20 +504,40 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
- void SetPred(Node* n, int output_idx, Predicate* pred) {
- CHECK(
- predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second);
+
+ // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
+ // bit of `should_revisit` if `pred` is different from the current predicate
+ // for the `output_idx` output of `n`.
+ void SetPredicate(Node* n, int output_idx, Predicate* pred,
+ std::vector<bool>* should_revisit) {
+ auto insert_result =
+ predicate_map_.insert({TensorId(n->name(), output_idx), pred});
+ if (!insert_result.second && insert_result.first->second != pred) {
+ VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
+ << insert_result.first->second->ToString() << " "
+ << insert_result.first->second << " to " << pred->ToString()
+ << " " << pred;
+ insert_result.first->second = pred;
+ if (should_revisit != nullptr) {
+ for (const Edge* e : n->out_edges()) {
+ (*should_revisit)[e->dst()->id()] = true;
+ }
+ }
+ }
}
- void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) {
+
+ void SetPredicate(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
+ std::vector<bool>* should_revisit) {
for (int output_idx : output_idxs) {
- SetPred(n, output_idx, pred);
+ SetPredicate(n, output_idx, pred, should_revisit);
}
}
- Status HandleSwitch(Node* n);
- Status HandleMerge(Node* n);
- Status HandleRecv(Node* n);
- Status HandleGeneric(Node* n);
+ Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
+ Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
+ Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
+ Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
+ Status HandleNode(Node* n, std::vector<bool>* should_revisit);
const Graph& graph_;
gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
@@ -397,14 +560,15 @@ std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
if (should_process) {
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
- CHECK(it != predicate_map_.end());
+ CHECK(it != predicate_map_.end()) << n->name();
incoming_preds.push_back(it->second);
}
}
return incoming_preds;
}
-Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
+Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
+ std::vector<bool>* should_revisit) {
std::vector<Predicate*> input_preds =
GetIncomingPreds(n, EdgeKind::kDataAndControl);
const Edge* pred_edge;
@@ -416,84 +580,252 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
- SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds));
+ SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
- SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds));
+ SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
- // Control is alive iff any inputs are alive.
- SetPred(n, Graph::kControlSlot,
- predicate_factory_.MakeAndPredicate(input_preds));
+ // Control is alive iff all inputs are alive.
+ SetPredicate(n, Graph::kControlSlot,
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
-Status DeadnessAnalysisImpl::HandleMerge(Node* n) {
+namespace {
+const Edge* FindUniqueBackedge(Node* merge) {
+ CHECK(merge->IsMerge());
+ const Edge* result = nullptr;
+ for (const Edge* e : merge->in_edges()) {
+ if (e->src()->IsNextIteration()) {
+ CHECK_EQ(result, nullptr)
+ << "Multiple backedges to " << merge->DebugString();
+ result = e;
+ }
+ }
+ return result;
+}
+
+// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
+// does not contain `symbolic_predicate` as an inner (not top-level) operand
+// then returns `Step`. Otherwise returns nullptr.
+Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
+ Predicate* symbolic_predicate,
+ Predicate* backedge_predicate) {
+ CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
+ if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
+ return nullptr;
+ }
+
+ std::vector<Predicate*> and_ops;
+ gtl::ArraySlice<Predicate*> recurrent_pred_ops =
+ backedge_predicate->GetOperands();
+
+ bool found_sym = false;
+ for (Predicate* and_op : recurrent_pred_ops) {
+ // We want the `symbol_predicate` to be the one of the operands of
+ // `backedge_predicate`,
+ if (and_op == symbolic_predicate) {
+ found_sym = true;
+ continue;
+ }
+
+ // but we don't want it to be present anywhere else in the formula. E.g. we
+ // don't want the recurrent predicate to be
+ // symbol_predicate&(X|symbol_predicate).
+ bool found_sym_as_inner_operand = false;
+ auto has_self_as_inner_operand = [&](Predicate* p) {
+ if (p == symbolic_predicate) {
+ found_sym_as_inner_operand = true;
+ return true; // Stop searching, we're done.
+ }
+
+ // Continue searching.
+ return false;
+ };
+
+ Predicate::Visit(and_op, has_self_as_inner_operand);
+ if (found_sym_as_inner_operand) {
+ return nullptr;
+ }
+ and_ops.push_back(and_op);
+ }
+
+ return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
+}
+} // namespace
+
+Status DeadnessAnalysisImpl::HandleMerge(Node* n,
+ std::vector<bool>* should_revisit) {
// Merge ignores deadness of its control inputs. A merge that isn't the
- // target of a backedge has is alive iff any of its data inputs are. We treat
- // the liveness of a merge that is the target of a backedge symbolically.
+ // target of a backedge has is alive iff any of its data inputs are. The
+ // liveness of a merge that is the target of a backedge can sometimes be
+ // represented using a AndRecurrencePredicate. If neither apply, we represent
+ // the liveness of the merge symbolically.
+
+ bool has_unvisited_backedge = false;
+ for (const Edge* e : n->in_edges()) {
+ if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
+ has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
+ }
+ }
+
+ auto it = predicate_map_.find(TensorId(n->name(), 0));
+ if (it == predicate_map_.end()) {
+ if (has_unvisited_backedge) {
+ // We're visiting this merge for the first time and it has an unvisited
+ // backedge.
+ Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate(
+ TensorId(n->name(), 0), /*must_be_true=*/false);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
+ return Status::OK();
+ }
- bool has_backedge = std::any_of(
- n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) {
- return !e->IsControlEdge() && e->src()->IsNextIteration();
- });
+ // We're visiting this merge for the first time and it is a acyclic merge.
+ Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
+ GetIncomingPreds(n, EdgeKind::kDataOnly));
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
+ return Status::OK();
+ }
- Predicate* input_data_pred =
- has_backedge ? predicate_factory_.MakeSymbolPredicate(
- TensorId(n->name(), 0), /*must_be_true=*/false)
- : predicate_factory_.MakeOrPredicate(
- GetIncomingPreds(n, EdgeKind::kDataOnly));
+ if (it->second->kind() == Predicate::Kind::kSymbol) {
+ // Last time we visited this merge we only got a symbolic predicate because
+ // of an unvisited backedge. Try to pattern match the predicate expression
+ // for that backedge (which should be visited now) into an and recurrence
+ // for the merge node.
+ if (const Edge* unique_backedge = FindUniqueBackedge(n)) {
+ if (Predicate* step = DeduceStepPredicate(
+ &predicate_factory_, it->second,
+ predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
+ // If the predicate for the backedge is "Sym&X" where "Sym" is the
+ // predicate for the merge then the merge has predicate {S,&,X} where S
+ // is the predicate for the merge ignoring the backedge.
+ std::vector<Predicate*> non_recurrent_inputs;
+ for (const Edge* e : n->in_edges()) {
+ if (e != unique_backedge) {
+ non_recurrent_inputs.push_back(
+ predicate_map_[InputEdgeToTensorId(e)]);
+ }
+ }
- SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred);
+ Predicate* start =
+ predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
+ Predicate* and_rec =
+ predicate_factory_.MakeAndRecurrencePredicate(start, step);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
+ return Status::OK();
+ }
+ }
+ }
return Status::OK();
}
-Status DeadnessAnalysisImpl::HandleRecv(Node* n) {
+Status DeadnessAnalysisImpl::HandleRecv(Node* n,
+ std::vector<bool>* should_revisit) {
// In addition to being alive or dead based on the inputs, a _Recv can also
// acquire a dead signal from a _Send.
std::vector<Predicate*> input_preds =
GetIncomingPreds(n, EdgeKind::kDataAndControl);
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
- SetPred(n, {0, Graph::kControlSlot},
- predicate_factory_.MakeAndPredicate(input_preds));
+ SetPredicate(n, {0, Graph::kControlSlot},
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
-Status DeadnessAnalysisImpl::HandleGeneric(Node* n) {
+Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
+ std::vector<bool>* should_revisit) {
// Generally nodes are alive iff all their inputs are alive.
Predicate* pred = predicate_factory_.MakeAndPredicate(
GetIncomingPreds(n, EdgeKind::kDataAndControl));
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
- SetPred(n, output_idx, pred);
+ SetPredicate(n, output_idx, pred, should_revisit);
+ }
+ SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
+ return Status::OK();
+}
+
+Status DeadnessAnalysisImpl::HandleNode(Node* n,
+ std::vector<bool>* should_revisit) {
+ if (n->IsSwitch()) {
+ TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
+ } else if (n->IsMerge()) {
+ TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
+ } else if (n->IsControlTrigger()) {
+ SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
+ nullptr);
+ } else if (n->IsRecv() || n->IsHostRecv()) {
+ TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
+ } else if (n->IsNextIteration()) {
+ TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
+ } else {
+ TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
}
- SetPred(n, Graph::kControlSlot, pred);
return Status::OK();
}
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
- GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
+ GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
+ return PopulateWithReversePostOrder(rpo);
+}
+Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
+ gtl::ArraySlice<Node*> rpo) {
// This an abstract interpretation over the deadness propagation semantics of
// the graph executor.
+ //
+ // We iterate over the graph twice, each time in RPO. On the first iteration
+ // merge nodes with backedges are mapped to symbolic predicates. On the
+ // second iteration we use the predicates assigned to the backedges in the
+ // previous iteration to infer a more precise predicate for the backedge merge
+ // nodes and all the nodes that transitively use it.
+ //
+ // We don't track the output indices for should_revisit. Instead, putting a
+ // node in `should_revisit` denotes that the deadness flowing out from any
+ // output from said node may have changed. This is fine; only switches
+ // propagate different deadness along different output edges, and since the
+ // delta is solely due to the input *values* (and not input deadness), the
+ // delta should not change in the second iteration.
+ std::vector<bool> should_revisit;
+ should_revisit.resize(graph_.num_node_ids());
for (Node* n : rpo) {
- if (n->IsSwitch()) {
- TF_RETURN_IF_ERROR(HandleSwitch(n));
- } else if (n->IsMerge()) {
- TF_RETURN_IF_ERROR(HandleMerge(n));
- } else if (n->IsControlTrigger()) {
- SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue());
- } else if (n->IsRecv() || n->IsHostRecv()) {
- TF_RETURN_IF_ERROR(HandleRecv(n));
- } else {
- TF_RETURN_IF_ERROR(HandleGeneric(n));
+ VLOG(4) << "Visiting " << n->name();
+ TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
+ if (n->IsNextIteration()) {
+ // If this is a backedge for a merge node then remember to reprocess the
+ // merge the next time we run.
+ for (const Edge* e : n->out_edges()) {
+ if (e->dst()->IsMerge()) {
+ should_revisit[e->dst()->id()] = true;
+ }
+ }
+ }
+ }
+
+ for (Node* n : rpo) {
+ // The nodes added to should_revisit in the previous loop need to be
+ // revisited now. Reprocesing these initial nodes may add *their* consumers
+ // to should_revisit, and these newly added nodes will also be processed by
+ // this very same loop. Since we're traversing the graph in reverse post
+ // order (producers before consumers) and HandleNode(n) can only ever add
+ // n's consumers to should_revisit, we won't "miss" an addition to
+ // should_revisit.
+ if (should_revisit[n->id()]) {
+ VLOG(4) << "Revisiting " << n->name();
+ TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
}
}
@@ -589,6 +921,15 @@ Status ComputePredicates(const Graph& graph,
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
+
+Status ComputePredicates(const Graph& graph,
+ gtl::ArraySlice<Node*> reverse_post_order,
+ PredicateMapTy* out_predicate_map) {
+ DeadnessAnalysisImpl impl(&graph);
+ TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
+ *out_predicate_map = impl.PredicateMapAsString();
+ return Status::OK();
+}
} // namespace deadness_analysis_internal
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index cdef405110..401d6e406a 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -26,6 +26,14 @@ namespace deadness_analysis_internal {
// testing purposes only.
using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
+
+// Returns a map describing the predicate each Tensor was mapped to. For
+// testing purposes only. Makes deadness analysis visit the graph in the order
+// specified in `reverse_post_order` which must be a valid RPO for the graph
+// minus NextIteration->Merge edges.
+Status ComputePredicates(const Graph& graph,
+ gtl::ArraySlice<Node*> reverse_post_order,
+ PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 6881095b51..cc9f102398 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -38,6 +38,9 @@ limitations under the License.
namespace tensorflow {
namespace {
+using deadness_analysis_internal::ComputePredicates;
+using deadness_analysis_internal::PredicateMapTy;
+
Status AnalyzeDeadness(Graph* graph,
std::unique_ptr<DeadnessAnalysis>* result) {
FixupSourceAndSinkEdges(graph);
@@ -51,13 +54,73 @@ ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
}
-Output CreateInductionVariable(const Scope& root, const string& prefix,
- const string& frame_name, int32 init) {
- Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
+TensorId ControlOutputFor(const Output& o) {
+ return {o.node()->name(), Graph::kControlSlot};
+}
+
+void VLogGraphIfAsked(const Graph& graph) {
+ if (VLOG_IS_ON(3)) {
+ GraphDef graph_def;
+ graph.ToGraphDef(&graph_def);
+ string serialized;
+ ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
+ LOG(INFO) << serialized;
+ }
+}
+
+struct InductionVarInfo {
+ Output induction_var;
+ Output loop_cond;
+};
+
+// Creates an induction variable with the following structure (simplified for
+// brevity):
+//
+// +---------------+
+// | initial_value |
+// +---------------+
+// |
+// |
+// v
+// +---------------+
+// | Enter |
+// +---------------+
+// |
+// |
+// v
+// +---------------+
+// +> | Merge | -+
+// | +---------------+ |
+// | | |
+// | | |
+// | v |
+// | +---------------+ |
+// | | LessThan10 | |
+// | +---------------+ |
+// | | |
+// | | |
+// | v |
+// | +---------------+ |
+// +----+- | Switch | <+
+// | | +---------------+
+// | | |
+// | | |
+// | | v
+// | | +---------------+
+// | +- | AddOne |
+// | +---------------+
+// | +---------------+
+// +-----> | Exit |
+// +---------------+
+InductionVarInfo CreateInductionVariable(const Scope& root,
+ const string& prefix,
+ const string& frame_name,
+ const Output& initial_value) {
Output enter_initial_value = ops::internal::Enter(
root.WithOpName(prefix + "/enter"), initial_value, frame_name);
- ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
+ ops::Merge iv(root.WithOpName(prefix + "/iv"),
+ {enter_initial_value, enter_initial_value});
Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
Output loop_cond_expr =
@@ -66,16 +129,84 @@ Output CreateInductionVariable(const Scope& root, const string& prefix,
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
- Output iv_next =
- ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
+ Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
+ latch.output_true, increment_by);
Output next_iteration =
- ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);
+ ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);
- root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
+ CHECK(root.graph()
+ ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
+ .ok());
root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
root.graph()->AddControlEdge(iv.output.node(), final_value.node());
- return iv.output;
+ return {iv.output, loop_cond};
+}
+
+InductionVarInfo CreateInductionVariable(const Scope& root,
+ const string& prefix,
+ const string& frame_name, int32 init) {
+ return CreateInductionVariable(
+ root, prefix, frame_name,
+ ops::Const(root.WithOpName(prefix + "/init"), init));
+}
+
+// Creates an induction variable with the following structure:
+//
+// +---------------+
+// | initial_value |
+// +---------------+
+// |
+// |
+// v
+// +---------------+
+// | Enter |
+// +---------------+
+// |
+// |
+// v
+// +---------------+
+// | Merge | <+
+// +---------------+ |
+// | |
+// | |
+// v |
+// +-----------+ +---------------+ |
+// | loop_cond | --> | Switch | -+
+// +-----------+ +---------------+
+// |
+// |
+// v
+// +---------------+
+// | Exit |
+// +---------------+
+struct DependentInductionVar {
+ Output induction_var;
+ ops::Switch latch;
+};
+
+DependentInductionVar CreateDependentLoopInvariantValue(
+ const Scope& root, const string& prefix, const string& frame_name,
+ const Output& loop_cond, const Output& value) {
+ Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
+ value, frame_name);
+ ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
+ ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
+ Output next_iteration = ops::NextIteration(
+ root.WithOpName(prefix + "/next_iteration"), latch.output_true);
+ CHECK(root.graph()
+ ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
+ .ok());
+ return {iv.output, latch};
+}
+
+DependentInductionVar CreateDependentLoopInvariantValue(
+ const Scope& root, const string& prefix, const string& frame_name,
+ const Output& loop_cond, int32 value) {
+ return CreateDependentLoopInvariantValue(
+ root, prefix, frame_name, loop_cond,
+ ops::Const(root.WithOpName(prefix + "/init"), value));
}
TEST(DeadnessAnalysisTest, BasicPositive) {
@@ -337,21 +468,224 @@ TEST(DeadnessAnalysisTest, HostRecv) {
TEST(DeadnessAnalysisTest, Loop) {
Scope root = Scope::NewRootScope().ExitOnError();
- Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
- Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
- Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
+ Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
+ Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
+ Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);
- std::unique_ptr<DeadnessAnalysis> result;
- TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
-
// NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
// noticed that. Today we are pessimistic here because we assign an
// uninterpreted symbol to merges with backedges.
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
+ VLogGraphIfAsked(*root.graph());
+
+ {
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
+ }
+ {
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+
+ // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
+ // produce the same deadness. But we're not that smart today.
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+ "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
+ "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})");
+ }
+}
+
+TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
+ Output dependent_iv0 =
+ CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0)
+ .induction_var;
+ Output dependent_iv1 =
+ CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0)
+ .induction_var;
+ Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);
+
+ VLogGraphIfAsked(*root.graph());
+
+ {
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ }
+ {
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
+ "{#true,&,*iv0/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
+ "{#true,&,(*iv0/cond:0 & iv0/iv:0)}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
+ "{#true,&,(*iv0/cond:0 & iv0/iv:0)}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+ "{#true,&,(*iv0/cond:0 & iv0/iv:0)}");
+ }
+}
+
+TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
+ // Create a merge that "looks like" a loop but isn't really. It has a value
+ // that does not depend on the merge on its backedge.
+ Scope root = Scope::NewRootScope().ExitOnError();
+ InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
+ DependentInductionVar dependent_iv =
+ CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
+ FixupSourceAndSinkEdges(root.graph());
+
+ // To make deadness analysis think that dependent_iv is a loop we need an RPO
+ // that visits the merge before the backedge. This is a legal RPO for
+ // deadness analysis since it ignores NextIteration->Merge edges during RPO.
+ // Right now dependent_iv has an edge from Merge to NextIteration so do the
+ // RPO with this edge in place. Then remove this edge to get our test case.
+ std::vector<Node*> rpo;
+ GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+ TF_ASSERT_OK(root.graph()->UpdateEdge(
+ iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
+
+ VLogGraphIfAsked(*root.graph());
+
+ {
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
+
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
+ "div0/iv:0");
+ }
+}
+
+TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ InductionVarInfo iv_outer =
+ CreateInductionVariable(root, "iv_outer", "frame", 0);
+ ops::Switch inner_value(root.WithOpName("outer_is_live"),
+ ops::Const(root.WithOpName("constant"), 5),
+ iv_outer.loop_cond);
+ InductionVarInfo iv_inner = CreateInductionVariable(
+ root, "iv_inner", "frame",
+ ops::internal::Enter(root.WithOpName("iv_inner/enter"),
+ inner_value.output_true, "frame_inner"));
+
+ Output dependent_outer_iv0 =
+ CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame",
+ iv_outer.loop_cond, 0)
+ .induction_var;
+ Output dependent_outer_iv1 =
+ CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame",
+ iv_outer.loop_cond, 0)
+ .induction_var;
+
+ Output dependent_inner_iv0 =
+ CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame",
+ iv_inner.loop_cond, dependent_outer_iv0)
+ .induction_var;
+ Output dependent_inner_iv1 =
+ CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame",
+ iv_inner.loop_cond, dependent_outer_iv1)
+ .induction_var;
+
+ Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
+ dependent_inner_iv1);
+
+ VLogGraphIfAsked(*root.graph());
+
+ {
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ }
+ {
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
+ "{#true,&,*iv_outer/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
+ "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&,"
+ "*iv_inner/cond:0}");
+
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
+ "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&,"
+ "(*iv_inner/cond:0 & iv_inner/iv:0)}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
+ "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&,"
+ "(*iv_inner/cond:0 & iv_inner/iv:0)}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+ "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&,"
+ "(*iv_inner/cond:0 & iv_inner/iv:0)}");
+ }
+}
+
+TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ InductionVarInfo iv_outer_0 =
+ CreateInductionVariable(root, "iv_outer_0", "frame", 0);
+ ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"),
+ ops::Const(root.WithOpName("constant"), 5),
+ iv_outer_0.loop_cond);
+ InductionVarInfo iv_inner_0 = CreateInductionVariable(
+ root, "iv_inner_0", "frame",
+ ops::internal::Enter(root.WithOpName("iv_inner_0/enter"),
+ inner_value_0.output_true, "frame_inner"));
+
+ InductionVarInfo iv_outer_1 =
+ CreateInductionVariable(root, "iv_outer_1", "frame", 1);
+ ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"),
+ ops::Const(root.WithOpName("constant"), 5),
+ iv_outer_1.loop_cond);
+ InductionVarInfo iv_inner_1 = CreateInductionVariable(
+ root, "iv_inner_1", "frame",
+ ops::internal::Enter(root.WithOpName("iv_inner_1/enter"),
+ inner_init_value_1.output_true, "frame_inner"));
+ Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var,
+ iv_inner_1.induction_var);
+
+ VLogGraphIfAsked(*root.graph());
+
+ {
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ }
+
+ {
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)],
+ "{#true,&,*iv_outer_0/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)],
+ "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&,"
+ "*iv_inner_0/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)],
+ "{#true,&,*iv_outer_1/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)],
+ "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&,"
+ "*iv_inner_1/cond:0}");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+ "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&,"
+ "*iv_inner_1/cond:0} & "
+ "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&,"
+ "*iv_inner_0/cond:0})");
+ }
}
TEST(DeadnessAnalysisTest, ControlInputs) {
@@ -454,9 +788,8 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
- deadness_analysis_internal::PredicateMapTy predicate_map;
- TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(),
- &predicate_map));
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
TensorId logical_and_output_0 = {logical_and.node()->name(),
Graph::kControlSlot};
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 90d5d56998..3e41e44ba9 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -39,7 +39,9 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -141,6 +143,10 @@ bool IsCompilableCall(const NodeDef& call_def,
<< ": could not instantiate: " << status;
return false;
}
+
+ auto release_handle_on_return = gtl::MakeCleanup(
+ [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
+
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
const FunctionDef& fdef = fbody->fdef;
@@ -412,6 +418,31 @@ Status FindCompilationCandidates(
return Status::OK();
}
+// Determine the global jit level which is ON if either the
+// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag
+// is true.
+OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
+ const GraphOptimizationPassOptions& options) {
+ OptimizerOptions::GlobalJitLevel global_jit_level =
+ options.session_options->config.graph_options()
+ .optimizer_options()
+ .global_jit_level();
+ if (global_jit_level == OptimizerOptions::DEFAULT) {
+ // To set compilation to be on by default, change the following line.
+ global_jit_level = OptimizerOptions::OFF;
+ }
+ legacy_flags::MarkForCompilationPassFlags* flags =
+ legacy_flags::GetMarkForCompilationPassFlags();
+ if (flags->tf_xla_auto_jit == -1 ||
+ (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
+ // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
+ // the setting in ConfigProto.
+ global_jit_level =
+ static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
+ }
+ return global_jit_level;
+}
+
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
@@ -434,22 +465,9 @@ Status MarkForCompilationPass::Run(
// TODO(phawkins): precompute the "GetCompilationDevice" properties of each
// device ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level =
- options.session_options->config.graph_options()
- .optimizer_options()
- .global_jit_level();
- if (global_jit_level == OptimizerOptions::DEFAULT) {
- // To set compilation to be on by default, change the following line.
- global_jit_level = OptimizerOptions::OFF;
- }
+ GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
- if (flags->tf_xla_auto_jit == -1 ||
- (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
- // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
- // the setting in ConfigProto.
- global_jit_level =
- static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
- }
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
bool fusion_only = flags->tf_xla_fusion_only;
@@ -517,9 +535,9 @@ Status MarkForCompilationPass::Run(
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
bool should_compile =
(ignore_registration || registration->enable_jit_by_default) &&
- global_jit_level > 0;
+ global_jit_level != OptimizerOptions::OFF;
if (!should_compile) {
- if (global_jit_level <= 0) {
+ if (global_jit_level == OptimizerOptions::OFF) {
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
} else {
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
@@ -530,6 +548,60 @@ Status MarkForCompilationPass::Run(
return RunImpl(options, is_compilable);
}
+static string RatioToString(int numerator, int denominator) {
+ return strings::Printf("%d / %d (%.2f%%)", numerator, denominator,
+ (100.0 * numerator) / denominator);
+}
+
+static void VLogClusteringSummary(const Graph& g) {
+ if (!VLOG_IS_ON(2)) {
+ return;
+ }
+
+ std::map<StringPiece, int> cluster_name_to_size;
+ std::map<StringPiece, std::map<StringPiece, int>>
+ cluster_name_to_op_histogram;
+ std::map<StringPiece, int> unclustered_op_histogram;
+ int clustered_node_count = 0;
+
+ for (Node* n : g.nodes()) {
+ gtl::optional<StringPiece> cluster_name = GetXlaClusterForNode(*n);
+ if (cluster_name) {
+ clustered_node_count++;
+ cluster_name_to_size[*cluster_name]++;
+ cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
+ } else {
+ unclustered_op_histogram[n->type_string()]++;
+ }
+ }
+
+ int unclustered_node_count = g.num_nodes() - clustered_node_count;
+
+ VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes();
+ VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
+ << RatioToString(clustered_node_count, g.num_nodes());
+
+ for (const auto& cluster_name_size_pair : cluster_name_to_size) {
+ StringPiece cluster_name = cluster_name_size_pair.first;
+ int size = cluster_name_size_pair.second;
+ VLOG(2) << " " << cluster_name << " "
+ << RatioToString(size, g.num_nodes());
+ for (const auto& op_count_pair :
+ cluster_name_to_op_histogram[cluster_name]) {
+ VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
+ << " instances";
+ }
+ }
+
+ if (!unclustered_op_histogram.empty()) {
+ VLOG(2) << " Unclustered nodes: "
+ << RatioToString(unclustered_node_count, g.num_nodes());
+ for (const auto& pair : unclustered_op_histogram) {
+ VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
+ }
+ }
+}
+
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
@@ -577,6 +649,8 @@ Status MarkForCompilationPass::RunImpl(
worklist.push_back(&clusters[node->id()]);
}
+ OptimizerOptions::GlobalJitLevel global_jit_level =
+ GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
@@ -616,13 +690,15 @@ Status MarkForCompilationPass::RunImpl(
}
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
- // edge. If even one of the nodes lacks an _XlaScope attribute,
+ // edge. This restriction is overridden if the global_jit_level is ON. If
+ // even one of the nodes lacks an _XlaScope attribute,
// then it is treated as a "bridge" and a cluster may be created
// along it. We may want to restrict this behavior to require
// all nodes marked with _XlaCompile=true to also have a
// _XlaScope property set (and raise an error otherwise); but
// for now we don't do this.
- if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
+ if (global_jit_level == OptimizerOptions::OFF &&
+ GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
from_scope != to_scope) {
continue;
@@ -718,6 +794,9 @@ Status MarkForCompilationPass::RunImpl(
dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
options.flib_def);
}
+
+ VLogClusteringSummary(*graph);
+
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index a780d4a936..9d7ac0d609 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -199,7 +199,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
FunctionDef noinline = compilable;
noinline.mutable_signature()->set_name("NoInlineFn");
- AddAttr("_noinline", bool(true), noinline.mutable_attr());
+ AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());
FunctionDefLibrary flib;
*flib.add_function() = compilable;
@@ -372,6 +372,44 @@ TEST(XlaCompilationTest, Loops) {
EXPECT_EQ(0, clusters.size());
}
+TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor())
+ .WithAttr(kXlaScopeAttr, "ScopeA"));
+ Node* b = ops::UnaryOp(
+ "Relu", a,
+ builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
+ ops::BinaryOp(
+ "MatMul", a, b,
+ builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
+ TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
+ SessionOptions session_options;
+ session_options.config.mutable_graph_options()
+ ->mutable_optimizer_options()
+ ->set_global_jit_level(OptimizerOptions::ON_2);
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
+ &graph, &flib_def, &session_options));
+ auto clusters = GetClusters(*graph);
+
+ // The computation is: C = A + relu(A)
+ // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
+ // In this case, the GlobalJitLevel overrides the scopes to cluster while
+ // ignoring scopes.
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
+}
+
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index a84b82e479..65669877f7 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
- std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
+ SessionOptions* session_options) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
@@ -26,12 +28,19 @@ namespace tensorflow {
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
+ opt_options.session_options = session_options;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunImpl(opt_options);
}
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ SessionOptions session_options;
+ return MarkForCompilation(graph, flib_def, &session_options);
+}
+
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
index b9a0531cb0..216baaf933 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -25,6 +25,11 @@ class MarkForCompilationPassTestHelper {
// `graph` to the CPU device. To make testing easier, ignores device
// registration, _XlaCompile attributes, input deadness and global jit level.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def,
+ SessionOptions* session_options);
+
+ // Like `MarkForCompilation` but creates a default SessionOptions.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
// Like `MarkForCompilation` but creates `flib_def` from the op registry.
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 2a2691a6a4..70e6d0be0f 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <stdlib.h>
#include <unordered_set>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@@ -101,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
}
std::unique_ptr<XlaDeviceAllocator> alloc =
- xla::MakeUnique<XlaDeviceAllocator>();
+ absl::make_unique<XlaDeviceAllocator>();
XlaDeviceAllocator* alloc_ptr = alloc.get();
state.allocators_[{backend, device_ordinal}] = std::move(alloc);
return alloc_ptr;
@@ -327,7 +328,7 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
// to those methods; see the bug for details. Our only saving grace at the
// moment is that this race doesn't seem to occur in practice.
if (use_gpu_device_info_) {
- auto gpu_device_info = MakeUnique<GpuDeviceInfo>();
+ auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
gpu_device_info->stream = stream_.get();
gpu_device_info->default_context = device_context_;
set_tensorflow_gpu_device_info(gpu_device_info.get());
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0a0c089241..175a571ddb 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -91,7 +91,8 @@ Status XlaTransferManager::TransferLiteralToDevice(
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
- if (UseMultipleStreams()) {
+ if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow(
+ stream_->parent(), shaped_buffer)) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 4efbb2d5d7..2ffce9298d 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -175,7 +176,7 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
- arg_buffers_[i] = xla::MakeUnique<ShapedBuffer>(
+ arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 8d36d0fa0a..07a9bf0d4a 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -70,7 +71,7 @@ class XlaTensor {
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
- xla::MakeUnique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
+ absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
}
// Some tensors on the device may have known values on the host. We use these
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 0aafda7fb4..5b7001b5a4 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1167,6 +1167,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types:
self._testBinary(
array_ops.tile,
+ np.array([[6], [3], [4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([6, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[6, 3, 4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([2, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
np.array([[6]], dtype=dtype),
np.array([1, 2], dtype=np.int32),
expected=np.array([[6, 6]], dtype=dtype))
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index ff097f80f1..3d21fb5864 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -443,7 +443,6 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertAllEqual((2, 3, 4), dz.shape.as_list())
def testNestedDefun(self):
- self.skipTest('Nested defuns do not work on TPU at the moment')
with self.test_scope():
@function.defun
diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py
index d01c676e7c..32ab5d08f0 100644
--- a/tensorflow/compiler/tests/reverse_ops_test.py
+++ b/tensorflow/compiler/tests/reverse_ops_test.py
@@ -32,14 +32,20 @@ class ReverseOpsTest(xla_test.XLATestCase):
def testReverseOneDim(self):
shape = (7, 5, 9, 11)
- for revdim in range(len(shape)):
+ for revdim in range(-len(shape), len(shape)):
self._AssertReverseEqual([revdim], shape)
def testReverseMoreThanOneDim(self):
shape = (7, 5, 9, 11)
+ # The offset is used to test various (but not all) combinations of negative
+ # and positive axis indices that are guaranteed to not collide at the same
+ # index.
for revdims in itertools.chain.from_iterable(
- itertools.combinations(range(len(shape)), k)
- for k in range(2, len(shape)+1)):
+ itertools.combinations(range(-offset,
+ len(shape) - offset), k)
+ for k in range(2,
+ len(shape) + 1)
+ for offset in range(0, len(shape))):
self._AssertReverseEqual(revdims, shape)
def _AssertReverseEqual(self, revdims, shape):
@@ -50,15 +56,16 @@ class ReverseOpsTest(xla_test.XLATestCase):
p = array_ops.placeholder(dtypes.int32, shape=shape)
axis = constant_op.constant(
np.array(revdims, dtype=np.int32),
- shape=(len(revdims),), dtype=dtypes.int32)
+ shape=(len(revdims),),
+ dtype=dtypes.int32)
rval = array_ops.reverse(p, axis).eval({p: pval})
slices = [
- slice(-1, None, -1) if d in revdims else slice(None)
- for d in range(len(shape))]
- self.assertEqual(
- pval[slices].flatten().tolist(),
- rval.flatten().tolist())
+ slice(-1, None, -1)
+ if d in revdims or d - len(shape) in revdims else slice(None)
+ for d in range(len(shape))
+ ]
+ self.assertEqual(pval[slices].flatten().tolist(), rval.flatten().tolist())
if __name__ == '__main__':
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 73adb0d243..124cf9da81 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -398,6 +398,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.lgamma,
+ np.array(0.5, dtype=dtype),
+ expected=np.array(np.log(np.pi) / 2, dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ math_ops.lgamma,
np.array(
[[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2],
[-3 / 2, -7 / 2, -11 / 2]],
@@ -420,6 +425,19 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))
+ # The actual result is complex. Take the real part.
+ self._assertOpOutputMatchesExpected(
+ math_ops.lgamma,
+ np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype),
+ expected=np.array(
+ [
+ np.log(np.pi) / 2 + np.log(2),
+ np.log(np.pi) / 2 - np.log(15) + np.log(8),
+ np.log(np.pi) / 2 - np.log(945) + np.log(32),
+ ],
+ dtype=dtype),
+ atol=1e-4)
+
self._assertOpOutputMatchesExpected(
math_ops.digamma,
np.array(
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index fda32c8a1c..575917d078 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -211,6 +211,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -443,21 +444,94 @@ cc_library(
)
cc_library(
+ name = "functionalize_control_flow_util",
+ srcs = [
+ "functionalize_control_flow_util.cc",
+ ],
+ hdrs = [
+ "functionalize_control_flow_util.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "functionalize_cond",
+ srcs = [
+ "functionalize_cond.cc",
+ ],
+ hdrs = [
+ "functionalize_cond.h",
+ ],
+ deps = [
+ ":functionalize_control_flow_util",
+ ":tf2xla_util",
+ "//tensorflow/compiler/jit:union_find",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
name = "functionalize_control_flow",
- srcs = ["functionalize_control_flow.cc"],
- hdrs = ["functionalize_control_flow.h"],
+ srcs = [
+ "functionalize_control_flow.cc",
+ ],
+ hdrs = [
+ "functionalize_control_flow.h",
+ ],
+ deps = [
+ ":functionalize_cond",
+ ":functionalize_control_flow_util",
+ ":functionalize_while",
+ ":tf2xla_util",
+ "//tensorflow/compiler/jit:union_find",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
+ name = "functionalize_while",
+ srcs = [
+ "functionalize_while.cc",
+ ],
+ hdrs = [
+ "functionalize_while.h",
+ ],
deps = [
+ ":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -485,6 +559,32 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "functionalize_cond_test",
+ srcs = ["functionalize_cond_test.cc"],
+ deps = [
+ ":functionalize_cond",
+ ":functionalize_control_flow",
+ ":test_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/compiler/tf2xla/cc:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:resource_variable_ops_op_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
cc_library(
name = "test_util",
testonly = 1,
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
new file mode 100644
index 0000000000..0f5471616e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -0,0 +1,1380 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+
+#include <algorithm>
+#include <deque>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+
+using xla::StatusOr;
+
+namespace tensorflow {
+namespace functionalize_cond {
+
+string DebugString(const CondStateMap::CondNode& node) {
+ return node.ToString();
+}
+
+// TODO(jpienaar): Move to OutputTensor.
+string DebugString(const OutputTensor& tensor) {
+ return strings::StrCat(tensor.node->name(), ":", tensor.index);
+}
+
+string DebugString(CondStateMap::CondId cond_state) {
+ if (cond_state == nullptr || cond_state->empty()) return "[]";
+ return strings::StrCat(
+ "[",
+ tensorflow::str_util::Join(
+ *cond_state, ", ",
+ [](string* output, const CondStateMap::CondNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
+ "]");
+}
+
+string Branch_Name(BranchType b) {
+ switch (b) {
+ case BranchType::kElseBranch:
+ return "else";
+ case BranchType::kThenBranch:
+ return "then";
+ case BranchType::kBoth:
+ return "both";
+ case BranchType::kNeither:
+ return "neither";
+ }
+}
+
+// Returns the predicate of a switch.
+Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
+ const Edge* pred_edge;
+ TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
+ // The predicate can be preceded by a identity node. Look through
+ // identity nodes to predicate.
+ while (pred_edge->src()->IsIdentity()) {
+ TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
+ }
+ *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
+ return Status::OK();
+}
+
+CondStateMap::CondNode::CondNode(Type type, Node* switch_node,
+ BranchType branch)
+ : type(type), branch(branch) {
+ if (type == Type::kSwitch) {
+ TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate));
+ }
+}
+
+string CondStateMap::CondNode::ToString() const {
+ switch (type) {
+ case Type::kSwitch:
+ return strings::StrCat("s(", DebugString(predicate), ",",
+ Branch_Name(branch), ")");
+ case Type::kMerge:
+ return "m";
+ case Type::kDead:
+ return "d";
+ }
+}
+
+bool CondStateMap::CondNode::operator==(const CondNode& other) const {
+ if (type != Type::kSwitch) return type == other.type;
+ return type == other.type && predicate == other.predicate &&
+ branch == other.branch;
+}
+
+bool CondStateMap::CondNode::operator!=(const CondNode& other) const {
+ return !(*this == other);
+}
+
+CondStateMap::CondStateMap(Graph* graph) {
+ node_to_condid_map_.resize(graph->num_node_ids());
+ // Initialize the dead state (empty state is designated with a nullptr).
+ dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)});
+}
+
+bool CondStateMap::IsDead(CondStateMap::CondId id) const {
+ return id == dead_id_;
+}
+
+bool CondStateMap::IsEmpty(CondStateMap::CondId id) const {
+ return id == nullptr;
+}
+
+size_t CondStateMap::CondHash::operator()(
+ const CondStateMap::CondNode& item) const {
+ return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate),
+ hash<BranchType>()(item.branch)),
+ hash<CondStateMap::CondNode::Type>()(item.type));
+}
+
+size_t CondStateMap::CondHash::operator()(
+ const CondStateMap::CondState& vec) const {
+ if (vec.empty()) return 0;
+ size_t h = (*this)(vec.front());
+ auto it = vec.begin();
+ for (++it; it != vec.end(); ++it) {
+ h = Hash64Combine(h, (*this)(*it));
+ }
+ return h;
+}
+
+// CondArgNode represents a input to the conditional and its corresponding
+// switch nodes.
+struct CondArgNode {
+ explicit CondArgNode(Node* src, int src_output)
+ : src(src), src_output(src_output) {}
+
+ string ToString() const {
+ return strings::StrCat("src=", src->name(), ":", src_output,
+ " switches=", NodesToString(switches));
+ }
+
+ Node* src;
+ int src_output;
+ std::array<Node*, 2> branch_copy;
+ std::vector<Node*> switches;
+};
+using CondArgNodes = std::vector<CondArgNode>;
+
+string DebugString(const CondArgNodes& nodes) {
+ return strings::StrCat(
+ "[",
+ tensorflow::str_util::Join(nodes, ", ",
+ [](string* output, const CondArgNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
+ "]");
+}
+
+CondStateMap::CondId CondStateMap::LookupId(const Node* node) const {
+ if (node->id() < node_to_condid_map_.size())
+ return node_to_condid_map_[node->id()];
+ return added_node_mapping_.at(node->id());
+}
+
+CondStateMap::CondId CondStateMap::GetUniqueId(
+ const CondStateMap::CondState& state) {
+ if (state.empty()) return nullptr;
+ return &*condstate_set_.insert(state).first;
+}
+
+const CondStateMap::CondState& CondStateMap::LookupState(
+ const Node* node) const {
+ return *LookupId(node);
+}
+
+void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) {
+ if (node->id() < node_to_condid_map_.size())
+ node_to_condid_map_[node->id()] = id;
+ else
+ added_node_mapping_[node->id()] = id;
+}
+
+void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); }
+
+string CondStateMap::CondStateToString(const Node* node) const {
+ return CondStateToString(LookupId(node));
+}
+
+string CondStateMap::CondStateToString(CondStateMap::CondId id) const {
+ return DebugString(id);
+}
+
+FunctionalizeCond::FunctionalizeCond(Graph* graph,
+ FunctionLibraryDefinition* library)
+ : cond_state_map_(graph), library_(library), graph_(graph) {}
+
+// Class representing the merge/switch nodes that will become a conditional.
+class Conditional {
+ public:
+ Conditional(OutputTensor predicate, FunctionalizeCond* parent,
+ CondStateMap* cond_state_map);
+
+ // Adds merge node that is part of this conditional.
+ Status AddMerge(Node* m);
+
+ // Constructs an If node from the merge nodes.
+ Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library);
+
+ private:
+ // Extracts the then/else bodies: creates new graphs with the nodes
+ // corresponding to the nodes in the then/else branches as of this conditional
+ // as function bodies.
+ Status ExtractBodies(Graph* graph);
+
+ // Builds the arguments that are the input to the If.
+ Status BuildArgumentNodes();
+
+ // Builds the If node for the extracted bodies with the given predicate.
+ Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Adds input edges to If node.
+ Status AddInputEdges(Graph* graph);
+
+ // Adds output edges from If node.
+ Status AddOutputEdges(Graph* graph);
+
+ // Adds switch node that is part of this conditional.
+ Status AddSwitch(Node* s);
+
+ // Internal name of conditional. The name is based on the first merge node
+ // added.
+ string name() const;
+
+ // The FunctionalizeCond instance that created this.
+ FunctionalizeCond* parent_;
+
+ // Mapping between nodes and their cond state.
+ CondStateMap* cond_state_map_;
+
+ // The predicate of the conditional.
+ OutputTensor predicate_;
+
+ // The predicate of the switches of the conditional. This may be different
+ // than predicate (which is initialized from the original graph) as the
+ // predicate could be the output of a newly created If node.
+ OutputTensor switch_predicate_;
+
+ // Switch nodes in graph that are part of this conditional.
+ std::set<Node*, NodeCmpByNameResourcesLast> switches_;
+
+ // Merge nodes in graph that are part of this conditional.
+ std::set<Node*, NodeCmpByNameResourcesLast> merges_;
+
+ // Vector of control inputs from outside the conditional to a node inside.
+ std::vector<Node*> external_control_inputs_;
+ std::vector<Node*> external_control_outputs_;
+
+ // Graphs corresponding to the then and else branch.
+ std::array<std::unique_ptr<Graph>, 2> bodies_;
+
+ // Maps from graph_ to the branch body's graph.
+ std::array<std::vector<Node*>, 2> node_maps_;
+
+ // The argument nodes created for the switches.
+ CondArgNodes cond_arg_nodes_;
+
+ // The constructed If node.
+ Node* if_node_ = nullptr;
+
+ // Whether the merge nodes of this conditional have been replaced.
+ bool replaced_ = false;
+};
+
+Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
+ CondStateMap* cond_state_map)
+ : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {}
+
+Status Conditional::AddMerge(Node* m) {
+ merges_.insert(m);
+ return Status::OK();
+}
+
+Status Conditional::AddSwitch(Node* s) {
+ VLOG(5) << "Adding switch " << s->DebugString();
+ OutputTensor predicate;
+ TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
+ if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
+ if (!(switch_predicate_ == predicate)) {
+ return errors::InvalidArgument(
+ "Merge nodes ", NodesToString(merges_),
+ " directly dominated by switch nodes with different predicates (",
+ DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
+ }
+ switches_.insert(s);
+ return Status::OK();
+}
+
+Status Conditional::BuildArgumentNodes() {
+ VLOG(1) << "Build function arguments";
+ struct Hash {
+ size_t operator()(const std::pair<Node*, int>& item) const {
+ return Hash64Combine(hash<Node*>()(item.first),
+ std::hash<int>()(item.second));
+ }
+ };
+
+ std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
+ for (Node* switch_node : switches_) {
+ const Edge* e;
+ TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
+ std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
+ if (input_index.find(key) == input_index.end()) {
+ input_index[key] = cond_arg_nodes_.size();
+ cond_arg_nodes_.emplace_back(key.first, key.second);
+ }
+ cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
+ }
+ VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
+
+ int arg_count = 0;
+ for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
+ DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_Arg", arg_count),
+ FunctionLibraryDefinition::kArgOp)
+ .Attr("T", dtype)
+ .Attr("index", arg_count)
+ .Finalize(bodies_[branch_index].get(),
+ &cond_arg_node.branch_copy[branch_index]));
+ }
+ for (Node* node : cond_arg_node.switches) {
+ for (const Edge* e : node->out_edges()) {
+ if (e->IsControlEdge()) continue;
+ int branch_index = e->src_output();
+ Node* src_copy = cond_arg_node.branch_copy[branch_index];
+ Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
+
+ // The graph may contain dead switch nodes,
+ if (dst_copy == nullptr) continue;
+
+ TF_RET_CHECK(dst_copy != nullptr)
+ << "Unable to find copied node for " << e->dst()->DebugString()
+ << " on branch " << Branch_Name(BranchType(branch_index));
+ // If the input goes directly to a merge then the merge has
+ // been replaced by a retval so the dst input is 0 instead of
+ // dst_input.
+ int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
+ bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
+ }
+ }
+ ++arg_count;
+ }
+
+ // Verify that all retvals have an input.
+ // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
+ // input.
+ for (Node* m : merges_) {
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ bool has_input = false;
+ for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
+ if (!e->IsControlEdge()) {
+ has_input = true;
+ break;
+ }
+ }
+ if (!has_input) {
+ return errors::Internal(
+ "Failed to functionalize control flow with merge '", m->name(),
+ "' that doesn't have input on ", Branch_Name(branch), " branch.");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status Conditional::ExtractBodies(Graph* graph) {
+ VLOG(2) << "Extracting bodies for " << name();
+ for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ bodies_[static_cast<int>(b)] =
+ absl::make_unique<Graph>(graph->op_registry());
+ }
+
+ auto find_branch = [&](const Edge* e) {
+ const auto& id = cond_state_map_->LookupId(e->src());
+ return IsSwitch(e->src()) ? BranchType(e->src_output())
+ : cond_state_map_->FindBranchOf(id, predicate_);
+ };
+
+ std::array<std::vector<Node*>, 2> stacks;
+ VLOG(5) << "Merges: " << NodesToString(merges_);
+ for (Node* m : merges_) {
+ VLOG(5) << "For merge: " << m->DebugString() << " "
+ << cond_state_map_->CondStateToString(m);
+ for (auto e : m->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BranchType branch = find_branch(e);
+ TF_RET_CHECK(branch == BranchType::kThenBranch ||
+ branch == BranchType::kElseBranch)
+ << "Error: " << e->src()->name()
+ << " is not on either then or else branch (" << Branch_Name(branch)
+ << ").";
+ Node* src = e->src();
+ if (IsSwitch(src)) {
+ // Switch node outputs and dependencies are handled separately.
+ TF_RETURN_IF_ERROR(AddSwitch(src));
+ } else {
+ stacks[static_cast<int>(branch)].push_back(src);
+ }
+ }
+ }
+
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto output = bodies_[branch_index].get();
+ auto& stack = stacks[branch_index];
+ VLOG(5) << "In branch: " << Branch_Name(branch) << " "
+ << NodesToString(stack);
+ std::vector<bool> visited(graph->num_node_ids(), false);
+ node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
+ auto& node_map = node_maps_[branch_index];
+
+ while (!stack.empty()) {
+ Node* n = stack.back();
+ stack.pop_back();
+
+ if (visited.at(n->id())) continue;
+ visited[n->id()] = true;
+
+ // Verify output edges and record control edges exitting scope.
+ for (const Edge* e : n->out_edges()) {
+ Node* dst = e->dst();
+ if (IsMerge(dst)) continue;
+ Node* src = e->src();
+
+ auto dst_id = cond_state_map_->LookupId(dst);
+ auto src_id = cond_state_map_->LookupId(src);
+ if (dst_id != src_id) {
+ if (e->IsControlEdge()) {
+ external_control_outputs_.push_back(e->src());
+ } else {
+ // Constants are treated specially to workaround the case of
+ // non-dominated constant nodes.
+ if (!IsConstant(src)) {
+ // TODO(b/78882471): A node that feeds into two different
+ // CondState is not necessarily an error so log a warning for now
+ // but revisit to improve the testing to enable making this an
+ // error.
+ LOG(WARNING) << errors::InvalidArgument(
+ "Graph contains node ", src->name(), " that feeds into node ",
+ dst->name(),
+ " but these nodes are in different control contexts (",
+ DebugString(src_id), " vs ", DebugString(dst_id),
+ " (detected during out edge testing)");
+ }
+ }
+ }
+ }
+
+ // Copying incomming edges to dst node.
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ // Skip src/dst node.
+ if (!src->IsOp()) continue;
+
+ Node* dst = e->dst();
+ if (IsSwitch(src)) {
+ // Switch node outputs and dependencies are handled separately.
+ TF_RETURN_IF_ERROR(AddSwitch(src));
+ continue;
+ }
+
+ // Verify input is from the same context.
+ auto src_id = cond_state_map_->LookupId(src);
+ auto dst_id = cond_state_map_->LookupId(dst);
+ if (IsMerge(dst) || src_id == dst_id) {
+ // TODO(jpienaar): The merge case can be more strict.
+ if (node_map.at(src->id()) == nullptr) {
+ node_map.at(src->id()) = output->CopyNode(src);
+ stack.push_back(src);
+ }
+ } else if (e->IsControlEdge()) {
+ external_control_inputs_.push_back(src);
+ } else {
+ // This shouldn't happen, this means we have an external data input
+ // not entering via a switch node. Work around this for constant
+ // nodes as some constant nodes are inserted without the required
+ // control context dominance.
+ if (IsConstant(src)) {
+ node_map.at(src->id()) = output->CopyNode(src);
+ } else {
+ return errors::InvalidArgument(
+ "Graph contains node ", src->name(), " that feeds into node ",
+ dst->name(),
+ " but these nodes are in different control contexts (",
+ DebugString(src_id), " vs ", DebugString(dst_id),
+ " (detected during in edge testing)");
+ }
+ }
+
+ Node* src_copy = node_map.at(e->src()->id());
+ int src_output = e->src_output();
+ if (node_map.at(dst->id()) == nullptr) {
+ node_map.at(dst->id()) = output->CopyNode(dst);
+ }
+ Node* dst_copy = node_map.at(e->dst()->id());
+ if (e->IsControlEdge()) {
+ // Skip control inputs from external context.
+ if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
+ } else {
+ output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
+ }
+ }
+ }
+ }
+
+ // Build return values from the merge nodes.
+ int index = 0;
+ for (Node* m : merges_) {
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto& node_map = node_maps_[branch_index];
+ auto output = bodies_[branch_index].get();
+ TF_ASSIGN_OR_RETURN(node_map[m->id()],
+ BuildRetvalNode(output, m->output_type(0), index));
+ }
+ ++index;
+
+ // Connect the input to the merge_ with the retval, except if it is a
+ // Swich node, which is handled separately.
+ for (auto e : m->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ int branch_index = static_cast<int>(find_branch(e));
+ auto& node_map = node_maps_[branch_index];
+ auto output = bodies_[branch_index].get();
+ Node* in = e->src();
+ if (!IsSwitch(in)) {
+ if (node_map.at(in->id()) == nullptr) {
+ node_map[in->id()] = output->CopyNode(in);
+ }
+ output->AddEdge(node_map[in->id()], e->src_output(),
+ node_map.at(m->id()), 0);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Conditional::BuildIfNode(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(2) << "Build cond function for " << name();
+ NodeDefBuilder builder(name(), "If");
+ const string branch_name[] = {"else_branch", "then_branch"};
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ static std::atomic<int64> sequence_num(0LL);
+ int64 id = ++sequence_num;
+
+ NameAttrList body_name;
+ body_name.set_name(strings::StrCat("_functionalize_if_",
+ branch_name[branch_index], "_", id));
+
+ VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
+ << "): "
+ << dump_graph::DumpGraphToFile(
+ "functionalize_cond_body_" + branch_name[branch_index],
+ *bodies_[branch_index], nullptr);
+
+ FunctionDef body_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
+ body_name.name(), &body_fdef));
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+ builder.Attr(branch_name[branch_index], body_name);
+ }
+
+ VLOG(3) << "Build input type";
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ DataTypeVector in_arg_types;
+ for (auto& kv : cond_arg_nodes_) {
+ bool inserted = false;
+ for (const Node* arg : kv.switches) {
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ builder.ControlInput(in_edge->src()->name());
+ } else {
+ if (!inserted) {
+ DataType dtype = arg->input_type(0);
+ inputs.emplace_back(NodeDefBuilder::NodeOut(
+ in_edge->src()->name(), in_edge->src_output(), dtype));
+ in_arg_types.push_back(dtype);
+ inserted = true;
+ }
+ }
+ }
+ }
+ builder.Attr("Tin", in_arg_types);
+
+ DataTypeVector out_type;
+ for (const Node* merge : merges_) {
+ DataType dtype = merge->output_type(0);
+ out_type.push_back(dtype);
+ }
+ builder.Attr("Tout", out_type);
+ VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
+
+ builder.Attr("Tcond", DT_BOOL);
+ builder.Device(predicate_.node->assigned_device_name());
+ // Conditional should be the first input ...
+ builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(),
+ predicate_.index,
+ predicate_.node->output_type(0)));
+ // ... followed by the other inputs.
+ builder.Input(inputs);
+
+ VLOG(3) << "Build If node";
+ NodeDef if_def;
+ TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
+ TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin()));
+
+ return Status::OK();
+}
+
+Status Conditional::AddInputEdges(Graph* graph) {
+ VLOG(2) << "AddInputEdges for " << if_node_->name();
+ int index = 0;
+ // Add predicate input.
+ graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index, if_node_,
+ index++);
+ // Add function body inputs.
+ for (auto& arg : cond_arg_nodes_) {
+ if (arg.src_output == Graph::kControlSlot) {
+ graph->AddControlEdge(arg.src, if_node_);
+ } else {
+ graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
+ }
+ }
+ for (Node* n : external_control_inputs_) {
+ graph->AddControlEdge(n, if_node_);
+ }
+ return Status::OK();
+}
+
+Status Conditional::AddOutputEdges(Graph* graph) {
+ VLOG(2) << "AddOutputEdges for " << if_node_->name();
+ int i = 0;
+ for (Node* node : merges_) {
+ TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
+ std::vector<const Edge*> edges(node->out_edges().begin(),
+ node->out_edges().end());
+ for (const Edge* edge : edges) {
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ if (edge->src_output() > 0) {
+ return errors::Unimplemented("Output of index (", edge->src_output(),
+ ") of merge node ", node->name());
+ }
+
+ bool control_edge = edge->IsControlEdge();
+ graph->RemoveEdge(edge);
+ if (control_edge) {
+ graph->AddControlEdge(if_node_, dst);
+ } else {
+ graph->AddEdge(if_node_, i, dst, dst_input);
+ }
+ }
+ ++i;
+ }
+ for (Node* n : external_control_outputs_) {
+ graph->AddControlEdge(if_node_, n);
+ }
+
+ return Status::OK();
+}
+
+Status Conditional::BuildAndReplace(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(1) << "Build If and replace merge nodes " << name();
+ if (replaced_) return Status::OK();
+
+ TF_RETURN_IF_ERROR(ExtractBodies(graph));
+ TF_RETURN_IF_ERROR(BuildArgumentNodes());
+
+ if (VLOG_IS_ON(3)) {
+ LOG(INFO) << "Extracted bodies:";
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto output = bodies_[branch_index].get();
+ LOG(INFO) << Branch_Name(branch) << ": "
+ << DebugString(output->ToGraphDefDebug());
+ }
+ }
+
+ TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
+ TF_RETURN_IF_ERROR(AddInputEdges(graph));
+ TF_RETURN_IF_ERROR(AddOutputEdges(graph));
+ TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
+ for (Node* m : merges_) cond_state_map_->MarkDead(m);
+
+ // Check that the if_node doesn't feed into itself.
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
+ "Converting to If failed.");
+
+ replaced_ = true;
+ return Status::OK();
+}
+
+string Conditional::name() const {
+ CHECK(!merges_.empty());
+ return strings::StrCat((*merges_.begin())->name(), "_if");
+}
+
+bool CondStateMap::ScopeIn(CondStateMap::CondId id,
+ CondStateMap::CondId* scope) {
+ if (id == nullptr) {
+ *scope = nullptr;
+ return true;
+ }
+ CondState state;
+ for (const CondNode& node : *id) {
+ if (node.type == CondNode::Type::kSwitch) {
+ state.push_back(node);
+ }
+ if (node.type == CondNode::Type::kMerge) {
+ if (state.empty()) {
+ return false;
+ }
+ DCHECK(state.back().type == CondNode::Type::kSwitch &&
+ state.back().branch == BranchType::kBoth);
+ state.pop_back();
+ }
+ }
+ *scope = GetUniqueId(state);
+ return true;
+}
+
+Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
+ int port) {
+ Node* id;
+ TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
+ .Input(if_node, port)
+ .Finalize(graph_, &id));
+ cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node));
+ return Status::OK();
+}
+
+StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
+ const Node* replacee) {
+ Status status;
+ Node* ret = graph_->AddNode(def, &status);
+ TF_RETURN_IF_ERROR(status);
+ CondStateMap::CondState state = cond_state_map_.LookupState(replacee);
+ state.pop_back();
+ VLOG(1) << "Adding If for " << replacee->name();
+ cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state));
+ return ret;
+}
+
+Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
+ VLOG(2) << "Propagating update state for " << replacee->name() << " "
+ << cond_state_map_.CondStateToString(replacee);
+ // Redo topological sort as the order could have changed.
+ // TODO(jpienaar): The original topological order could also be updated
+ // dynamically if needed.
+ std::vector<Node*> rev_topo_order;
+ GetPostOrder(*graph_, &rev_topo_order);
+
+ // All the outputs of the new node could potentially be updated.
+ std::unordered_set<Node*> changed;
+ for (auto n : replacee->out_nodes())
+ if (n->IsOp()) changed.insert(n);
+
+ // Iterate through the changed/possible changed nodes in topological order.
+ for (auto it = rev_topo_order.rbegin();
+ it != rev_topo_order.rend() && !changed.empty(); ++it) {
+ if (changed.find(*it) != changed.end()) {
+ // Update the node state.
+ Node* n = *it;
+ CondStateMap::CondId old_state = cond_state_map_.LookupId(n);
+ cond_state_map_.ResetId(n, nullptr);
+ TF_RETURN_IF_ERROR(DetermineCondState(n));
+ if (cond_state_map_.LookupId(n) != old_state) {
+ for (auto out : n->out_nodes())
+ if (out->IsOp()) changed.insert(out);
+ }
+ changed.erase(n);
+ }
+ }
+ return Status::OK();
+}
+
+// Returns the most restrictive branch of two branches or neither. This is the
+// meet operator of the BranchType lattice.
+BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
+ if (lhs == rhs) return lhs;
+ if (lhs == BranchType::kNeither) return rhs;
+ if (rhs == BranchType::kNeither) return lhs;
+ if (lhs == BranchType::kBoth) return rhs;
+ if (rhs == BranchType::kBoth) return lhs;
+ return BranchType::kNeither;
+}
+
+CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds(
+ CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
+ CondId lhs_scope;
+ CondId rhs_scope;
+ bool could_determine_scope = ScopeIn(lhs, &lhs_scope);
+ could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope);
+ if (!could_determine_scope) return kIncomparable;
+
+ // Returns whether a contains b.
+ auto contains = [&](CondId a, CondId b) {
+ // Handle empty states.
+ if (a == nullptr && b != nullptr) return true;
+ if (a == nullptr && b == nullptr) return true;
+ if (a != nullptr && b == nullptr) return false;
+
+ if (a->size() > b->size()) return false;
+ auto a_it = a->begin();
+ auto b_it = b->begin();
+ while (a_it != a->end()) {
+ if (*a_it != *b_it) {
+ if (!(a_it->predicate == b_it->predicate)) return false;
+ BranchType mb = MeetBranch(a_it->branch, b_it->branch);
+ if (mb != b_it->branch) return false;
+ }
+ ++a_it;
+ ++b_it;
+ }
+ return true;
+ };
+
+ bool lhs_contains_rhs = contains(lhs_scope, rhs_scope);
+ bool rhs_contains_lhs = contains(rhs_scope, lhs_scope);
+ if (lhs_contains_rhs && rhs_contains_lhs) return kEqual;
+ if (lhs_contains_rhs) return kLhsContainsRhs;
+ if (rhs_contains_lhs) return kRhsContainsLhs;
+ return kIncomparable;
+}
+
+BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
+ if (IsEmpty(id)) return BranchType::kNeither;
+ gtl::optional<BranchType> b;
+ const CondState& nodes = *id;
+ for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
+ if (it->type == CondStateMap::CondNode::Type::kSwitch &&
+ it->predicate == predicate) {
+ if (b.has_value()) {
+ b = MeetBranch(*b, it->branch);
+ } else {
+ b = it->branch;
+ }
+ if (*b == BranchType::kNeither) {
+ LOG(FATAL) << "Inconsistent state for node: " << DebugString(id);
+ }
+ }
+ }
+ return b.has_value() ? *b : BranchType::kNeither;
+}
+
+StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst) {
+ VLOG(4) << "Joining src=" << DebugString(src) << " [" << src
+ << "] and dst=" << DebugString(dst) << " [" << dst << "]";
+
+ if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src;
+ if (cond_state_map_.IsDead(dst)) return dst;
+
+ // Nothing to do if the CondState is the same.
+ if (src == dst) return src;
+
+ CondStateMap::CondId src_scope;
+ CondStateMap::CondId dst_scope;
+ if (!cond_state_map_.ScopeIn(src, &src_scope))
+ return errors::Unimplemented(
+ "Predicates that must hold for node to execute are invalid! ",
+ DebugString(src));
+ if (!cond_state_map_.ScopeIn(dst, &dst_scope))
+ return errors::Unimplemented(
+ "Predicates that must hold for node to execute are invalid! ",
+ DebugString(dst));
+
+ auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope);
+ switch (result) {
+ case CondStateMap::kIncomparable:
+ return errors::InvalidArgument(
+ "Graph contains node with inputs predicated on incompatible "
+ "predicates: ",
+ DebugString(src), " and ", DebugString(dst));
+ case CondStateMap::kEqual:
+ // If both respect the same predicates, propagate the longer constraint.
+ if ((src != nullptr && dst == nullptr) ||
+ (src != nullptr && dst != nullptr && src->size() > dst->size()))
+ return src;
+ else
+ return dst;
+ case CondStateMap::kLhsContainsRhs:
+ // src contains dst, so dst is already more restrictive.
+ return dst;
+ case CondStateMap::kRhsContainsLhs:
+ // dst contains src, so src is more restrictive.
+ return src;
+ }
+}
+
+StatusOr<CondStateMap::CondState::const_iterator>
+FindThenElseSwitchForPredicate(const OutputTensor& pred,
+ CondStateMap::CondId id) {
+ for (auto it = id->begin(); it != id->end(); ++it) {
+ // Along every path one there can be only one instance of a then or else
+ // switch for a given predicate, so return once found.
+ if (it->type == CondStateMap::CondNode::Type::kSwitch &&
+ it->predicate == pred &&
+ (it->branch == BranchType::kThenBranch ||
+ it->branch == BranchType::kElseBranch))
+ return it;
+ }
+ return errors::Internal("Unable to find then/else branch with predicate ",
+ DebugString(pred), " for ", DebugString(id));
+}
+
+StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst) {
+ // Determine the flow state when joining two states for a merge
+ // node. Combining the two states for a merge node is effectively performing a
+ // disjunction of the states along the different input edges. For a merge that
+ // can be transformed into a If the two inputs paths have to have a predicate
+ // on which they differ (e.g., along one edge predicate `p` has to hold while
+ // on another it should not). This function first determines this predicate
+ // and then the resultant state is the common path between the two inputs
+ // followed by s(p, both).
+ VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
+ << DebugString(dst);
+ if (cond_state_map_.IsEmpty(dst)) return src;
+
+ if (cond_state_map_.IsDead(src)) return src;
+ if (cond_state_map_.IsDead(dst)) return dst;
+
+ CondStateMap::CondId src_scope;
+ CondStateMap::CondId dst_scope;
+ if (!cond_state_map_.ScopeIn(src, &src_scope))
+ return errors::Unimplemented(
+ "Predicates that must hold for node to execute are invalid! ",
+ DebugString(src));
+ if (!cond_state_map_.ScopeIn(dst, &dst_scope))
+ return errors::Unimplemented(
+ "Predicates that must hold for node to execute are invalid! ",
+ DebugString(dst));
+
+ TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr)
+ << "Illegal merge inputs from outer scope: src=" << DebugString(src)
+ << " dst=" << DebugString(dst);
+ auto src_it = src_scope->begin();
+ auto dst_it = dst_scope->begin();
+
+ // Find branch divergent condition.
+ OutputTensor pred;
+ while (src_it != src_scope->end() && dst_it != dst_scope->end()) {
+ if (*src_it != *dst_it) {
+ VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and "
+ << DebugString(*dst_it);
+ if (!(src_it->predicate == dst_it->predicate)) {
+ return errors::InvalidArgument(
+ "Unable to find common predicate which holds for one input "
+ "but not the other of the merge node.");
+ }
+ pred = src_it->predicate;
+ break;
+ }
+ ++src_it;
+ ++dst_it;
+ }
+
+ if (pred.node == nullptr)
+ return errors::InvalidArgument("Unable to determine predicate for merge.");
+
+ TF_ASSIGN_OR_RETURN(auto div_src_it,
+ FindThenElseSwitchForPredicate(pred, src));
+ TF_ASSIGN_OR_RETURN(auto div_dst_it,
+ FindThenElseSwitchForPredicate(pred, dst));
+ TF_RET_CHECK(*div_src_it != *div_dst_it);
+
+ CondStateMap::CondState result;
+ // Populate result with the longest/most restrictive path up to the divergent
+ // node. For example, if the one input is `[switch(pred:0, then)]` and the
+ // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created
+ // in gradient of cond test), then the resultant state here should be
+ // `[switch(pred:0, both), merge, switch(pred:0, both)]`.
+ if (std::distance(src->begin(), div_src_it) >
+ std::distance(dst->begin(), div_dst_it)) {
+ result.assign(src->begin(), std::next(div_src_it));
+ } else {
+ result.assign(dst->begin(), std::next(div_dst_it));
+ }
+ result.back().branch = BranchType::kBoth;
+ return cond_state_map_.GetUniqueId(result);
+}
+
+CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
+ Node* src = e->src();
+ CondStateMap::CondId id = cond_state_map_.LookupId(e->src());
+ if (IsMerge(src)) {
+ CondStateMap::CondState state;
+ if (id != nullptr) state = *id;
+ state.emplace_back(CondStateMap::CondNode::Type::kMerge);
+ return cond_state_map_.GetUniqueId(state);
+ }
+ if (IsSwitch(src)) {
+ CondStateMap::CondState state;
+ if (id != nullptr) state = *id;
+ if (e->IsControlEdge()) {
+ state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
+ BranchType::kBoth);
+ } else {
+ state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
+ BranchType(e->src_output()));
+ }
+ return cond_state_map_.GetUniqueId(state);
+ }
+ return id;
+}
+
+Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
+ // Only Merge nodes with two inputs are supported, but if this is a redundant
+ // merge, then the dead edge may already have been removed (if due to a
+ // switch) and so the input count would be incorrect.
+ if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst)))
+ return Status::OK();
+
+ int data_inputs = 0;
+ for (auto e : dst->in_edges()) {
+ Node* src = e->src();
+ VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
+ << cond_state_map_.CondStateToString(src);
+ if (!src->IsOp()) continue;
+ if (!e->IsControlEdge()) ++data_inputs;
+
+ CondStateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ cond_state_map_.ResetId(dst, id_or.ValueOrDie());
+ }
+
+ // Incomplete Merge nodes are not supported.
+ if (data_inputs != 2) {
+ return errors::Unimplemented(
+ dst->name(), " only has ", data_inputs,
+ " inputs, while only merge nodes with two inputs supported.");
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::DetermineCondState(Node* dst) {
+ // The logic for the merge and non-merge case differ: for non-merge it is
+ // the most restrictive CondState, while for merge nodes the
+ // resultant state is less restrictive than either.
+ if (IsMerge(dst)) {
+ TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst));
+ } else {
+ // Handle non-merge join.
+ for (auto e : dst->in_edges()) {
+ VLOG(5) << "Processing forward flow for: " << e->DebugString() << " "
+ << cond_state_map_.CondStateToString(dst);
+ Node* src = e->src();
+ if (!src->IsOp()) continue;
+
+ // Joining the state between the current and propagated state.
+ CondStateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ cond_state_map_.ResetId(dst, id_or.ValueOrDie());
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
+ // Handle redundant merge nodes. A merge node is considered redundant if
+ // one input edge is dead while the other has a value.
+ if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node)))
+ return Status::OK();
+
+ const Edge* non_dead_edge = nullptr;
+ for (auto e : node->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ Node* src = e->src();
+
+ // Handle merge with dead state.
+ const auto& src_id = cond_state_map_.LookupId(src);
+ if (!cond_state_map_.IsDead(src_id)) {
+ non_dead_edge = e;
+ break;
+ }
+ }
+
+ if (non_dead_edge == nullptr) {
+ return errors::InvalidArgument("Merge node ", node->name(),
+ " has no non-dead inputs.");
+ }
+ cond_state_map_.MarkDead(node);
+ delete_nodes_.push_back(node->id());
+ VLOG(5) << "removing redundant merge: " << node->name();
+ while (!node->out_edges().empty()) {
+ const Edge* oe = *node->out_edges().begin();
+ Node* dst_node = oe->dst();
+ int dst_port = oe->dst_input();
+ graph_->RemoveEdge(oe);
+ graph_->AddEdge(non_dead_edge->src(),
+ dst_port == Graph::kControlSlot
+ ? Graph::kControlSlot
+ : non_dead_edge->src_output(),
+ dst_node, dst_port);
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
+ // Handle redundant switch nodes. A switch node is considered redundant if
+ // the predicate of the switch already holds on the current branch. E.g., if
+ // p is the predicate of the switch but p is already known to hold on this
+ // branch, then the switch can be removed and the dead state propagated
+ // along one. The checking of predicate is based on the exact predicate
+ // (rather than boolean equivalence) and aimed at redundant switches as
+ // currently generated by gradient code.
+ OutputTensor pred;
+ TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));
+ auto dst_id = cond_state_map_.LookupId(node);
+ BranchType b = cond_state_map_.FindBranchOf(dst_id, pred);
+ // Determine if we are already on a branch where the switch predicate is
+ // true/false.
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
+ return Status::OK();
+
+ VLOG(5) << "Redundant switch " << node->name();
+ const Edge* value_edge;
+ TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
+ Node* val_node = value_edge->src();
+ int val_port = value_edge->src_output();
+ while (!node->out_edges().empty()) {
+ auto e = *node->out_edges().begin();
+ Node* dst_node = e->dst();
+ int dst_input = e->dst_input();
+ int switch_branch = e->src_output();
+ graph_->RemoveEdge(e);
+ if (switch_branch == Graph::kControlSlot) {
+ if (IsMerge(dst_node)) {
+ auto id_or =
+ JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node));
+ TF_RETURN_IF_ERROR(id_or.status());
+ cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+ } else {
+ auto id_or =
+ JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node));
+ TF_RETURN_IF_ERROR(id_or.status());
+ cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+ }
+ } else if (BranchType(switch_branch) != b) {
+ cond_state_map_.MarkDead(dst_node);
+ delete_nodes_.push_back(dst_node->id());
+ continue;
+ }
+ graph_->AddEdge(
+ val_node,
+ switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
+ dst_node, dst_input);
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::DetermineCondStates(
+ std::vector<Node*> rev_topo_order) {
+ // The state that is propagated along the given edge.
+ for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
+ Node* dst = *it;
+ TF_RETURN_IF_ERROR(DetermineCondState(dst));
+ if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
+ if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
+
+ VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst);
+ }
+ return Status::OK();
+}
+
+void FunctionalizeCond::DeleteReachableNodes() {
+ // Delete all nodes that have been extracted or are reachable from
+ // deleted/dead nodes. The input and outgoing edges should have already been
+ // removed.
+ std::vector<bool> deleted(graph_->num_node_ids(), false);
+ // Don't try to delete source or sink nodes.
+ deleted[graph_->kSourceId] = true;
+ deleted[graph_->kSinkId] = true;
+ while (!delete_nodes_.empty()) {
+ int d_id = delete_nodes_.front();
+ delete_nodes_.pop_front();
+ if (deleted[d_id]) continue;
+ Node* d = graph_->FindNodeId(d_id);
+ // Switch and Merge nodes could have been deleted already.
+ if (d == nullptr) continue;
+ for (const Edge* e : d->out_edges()) {
+ delete_nodes_.push_back(e->dst()->id());
+ }
+ deleted[d_id] = true;
+ graph_->RemoveNode(d);
+ }
+}
+
+void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
+ // Sort merge nodes by nesting depth.
+ using sort_pair = std::pair<int, Node*>;
+ std::vector<sort_pair> inner_to_outer_merge_order;
+ inner_to_outer_merge_order.reserve(merge_order->size());
+ for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
+ Node* merge = *it;
+ CondStateMap::CondId id = cond_state_map_.LookupId(merge);
+ int depth = 0;
+ for (auto cond_node_it = id->begin(); cond_node_it != id->end();
+ ++cond_node_it) {
+ if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch &&
+ (cond_node_it->branch == BranchType::kThenBranch ||
+ cond_node_it->branch == BranchType::kElseBranch)) {
+ ++depth;
+ }
+ }
+ inner_to_outer_merge_order.emplace_back(depth, merge);
+ }
+ std::stable_sort(
+ inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
+ [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
+ merge_order->clear();
+ for (sort_pair t : inner_to_outer_merge_order) {
+ merge_order->push_back(t.second);
+ }
+}
+
+Status FunctionalizeCond::FunctionalizeInternal() {
+ // The general approach for converting a tf.cond (as lowered via switch/merge
+ // nodes) to a functional if is as follows:
+ // 1. Determine the topological order and collect all the switch and merge
+ // nodes in the graph;
+ // 2. Compute the predicates and dominance structure for all the nodes in the
+ // graph - this includes which predicate must be true for a op to execute
+ // (predicate values are considered directly rather than attempting to
+ // determine deeper equivalence). We shall refer to this structure as the
+ // CondState;
+ // 3. Sort the merge nodes by nesting depth;
+ // 4. Extract merge nodes together that have the same CondState and whose
+ // input nodes have the same state from the innermost to the outermost into
+ // IfOps; Note: In the above only nodes paths that converge to a merge node
+ // will be considered for removal.
+
+ // Perform a DFS over the graph and
+ // * Determine the reverse topological order of the nodes (there should be no
+ // cycles at this point so the post-order numbering corresponds to the
+ // reverse topological sorting);
+ // * Record reverse topological for merge and switch nodes;
+ std::vector<Node*> rev_topo_order;
+ std::vector<int> switch_ids;
+ std::vector<Node*> merge_order;
+ DFS(*graph_, nullptr, [&](Node* n) {
+ if (IsSwitch(n)) {
+ switch_ids.push_back(n->id());
+ }
+ if (IsMerge(n)) {
+ merge_order.push_back(n);
+ }
+ if (n->IsOp()) {
+ rev_topo_order.push_back(n);
+ }
+ });
+
+ // No merges to functionalize.
+ if (merge_order.empty()) {
+ // No merges mean no switch values consumed (as only considering values
+ // fetchable as output of merge);
+ for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) {
+ graph_->RemoveNode(graph_->FindNodeId(*it));
+ }
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order)));
+
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+
+ // Sort the merge nodes from innermost outwards.
+ SortMergeNodes(&merge_order);
+
+ // Extract from innermost out.
+ for (auto it = merge_order.begin(); it != merge_order.end(); ++it) {
+ Node* merge = *it;
+ auto id = cond_state_map_.LookupId(merge);
+ if (cond_state_map_.IsDead(id)) continue;
+
+ // Construct a Conditional with the predicate of the merge (which is the
+ // last entry of the CondState for the merge) and this as parent.
+ DCHECK(id->back().predicate.node != nullptr);
+ Conditional cond(id->back().predicate, this, &cond_state_map_);
+ TF_RETURN_IF_ERROR(cond.AddMerge(merge));
+
+ // Find all merge nodes with the same CondId. This is done repeatedly as
+ // the CondId can change due replaced conditionals. E.g., the one branch
+ // could previously have had a conditional nested in it, and so would have
+ // had CondState with sub-state [switch(p,b),m] (where p is some predicate),
+ // post removing the nested conditional that sub-state would no longer be
+ // path of the propagated state along that path.
+ auto end = merge_order.end();
+ for (auto merge_candidate_it = std::next(it); merge_candidate_it != end;
+ ++merge_candidate_it) {
+ auto merge_candidate_it_id =
+ cond_state_map_.LookupId(*merge_candidate_it);
+ if (merge_candidate_it_id != id) continue;
+ TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it));
+ }
+
+ TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_));
+
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
+ }
+
+ // All remaining Switch nodes are not reachable from a Merge node and
+ // removed. This is to account for dead Switch nodes.
+ for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
+ for (Node* m : merge_order) delete_nodes_.push_back(m->id());
+ DeleteReachableNodes();
+
+ return Status::OK();
+}
+
+void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
+ const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
+
+ for (Node* n : graph_->nodes()) {
+ n->ClearAttr(kCondGroupDebugAttr);
+ n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n));
+ }
+ LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
+ << dump_graph::DumpGraphToFile(
+ strings::StrCat("functionalize_", name), *graph_, library_);
+}
+
+Status FunctionalizeCond::Functionalize(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(1) << "FunctionalizeCond::Functionalize";
+ FunctionalizeCond fc(graph, library);
+ return fc.FunctionalizeInternal();
+}
+
+} // namespace functionalize_cond
+
+Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
+ // FunctionalizeControlFlow is invoked for every function, so the loops's
+ // bodies and conditionals that were extracted into functions will be handled
+ // in successive invocations.
+ return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
new file mode 100644
index 0000000000..86436011c6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
+
+#include <deque>
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Functionalize all the switch-merge nodes of a loop-free graph into If
+// nodes. That is, attempt to transform every remaining switch and merge nodes
+// in the graph into If nodes.
+// Precondition: All while loops have been removed from graph.
+Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
+
+// Internal functions/classes exposed for testing purposes.
+namespace functionalize_cond {
+
+// All nodes are assumed to be either in no branch, then branch, else branch,
+// or both branches (such as merge nodes).
+// The code below relies on Else and Then being 0 and 1 (corresponding to the
+// switch outputs). Both and Neither are arbitrary.
+enum class BranchType {
+ kElseBranch = 0,
+ kThenBranch = 1,
+ kBoth = 2,
+ kNeither = 3,
+};
+
+// CondStateMap is responsible for mapping from each graph Node to a CondState,
+// where each CondState is the array of CondNodes (corresponding to switch,
+// merge or dead states) as described below. For efficiency, this class interns
+// the CondState, so that CondState equality comparisons are simply pointer
+// comparisons.
+class CondStateMap {
+ public:
+ explicit CondStateMap(Graph* graph);
+
+ // Represents an entry in the CondState. An entry can either be the
+ // switch (along with predicate), merge, or dead:
+ // * switch node indicates a node that is executed along a branch with the
+ // given predicate - a branch can be then, else or both;
+ // * merge node indicates that the node is executed as output of a merge;
+ // * dead indicates that this node can never be executed;
+ struct CondNode {
+ enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 };
+
+ CondNode(Type type, Node* switch_node = nullptr,
+ BranchType branch = BranchType::kNeither);
+
+ string ToString() const;
+ bool operator==(const CondNode& other) const;
+ bool operator!=(const CondNode& other) const;
+
+ // Type of node.
+ Type type;
+
+ // Predicate and branch, only used when type is kSwitch.
+ OutputTensor predicate;
+ BranchType branch;
+ };
+
+ // A node in the graph is executed when multiple conditions hold. The order
+ // represents the nesting of the predicates that hold and is used when
+ // extracting the nested conditionals.
+ using CondState = std::vector<CondNode>;
+
+ // Every unique ID is mapped to a CondState.
+ using CondId = const CondState*;
+
+ // Returns the CondId for a given node.
+ CondId LookupId(const Node* node) const;
+
+ // Returns the unique CondId for CondState.
+ CondId GetUniqueId(const CondState& state);
+
+ // Returns the CondState for a Node.
+ // REQUIRES: node has a non-empty CondState.
+ const CondState& LookupState(const Node* node) const;
+
+ // Resets the CondId for a given node.
+ void ResetId(const Node* node, CondId id);
+
+ // Marks `node` as dead.
+ void MarkDead(const Node* node);
+
+ // Determine branch execution of CondState.
+ BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
+
+ // Enum to represent whether one cond flow state contains another.
+ enum ContainsResult {
+ kIncomparable,
+ kEqual,
+ kLhsContainsRhs,
+ kRhsContainsLhs
+ };
+
+ // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e.,
+ // [(p,t)] contains [(p,t), (r,t)].
+ ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs);
+
+ // Returns textual representation of node's CondState.
+ string CondStateToString(const Node* node) const;
+ string CondStateToString(CondId id) const;
+
+ // Returns whether the cond state is the dead state.
+ bool IsDead(CondId id) const;
+
+ // Returns whether the cond state is the empty state.
+ bool IsEmpty(CondId id) const;
+
+ // Computes the predicates that have to hold for a node to execute and returns
+ // whether it was possible to determine the predicates that must hold. `scope`
+ // is populated with these predicates. Scope differs from state in that it
+ // does not include merge and both nodes.
+ bool ScopeIn(CondId id, CondId* scope);
+
+ private:
+ // Hash for CondNode and CondState.
+ struct CondHash {
+ size_t operator()(const CondNode& item) const;
+ size_t operator()(const CondState& vec) const;
+ };
+
+ // Set to keep track of unique CondStates.
+ // Pointers to the entries in the unordered set are used as identifiers:
+ // unordered_set guarantees that the pointers remain the same.
+ std::unordered_set<CondState, CondHash> condstate_set_;
+
+ // Mapping from Node id to CondId.
+ std::vector<CondId> node_to_condid_map_;
+
+ // Track the CondId for newly inserted nodes. We use a vector to quickly map
+ // from Node id in the original graph to the CondId, but there will be nodes
+ // added to the original graph (such as If nodes) whose CondState needs to be
+ // tracked too.
+ std::unordered_map<int, CondId> added_node_mapping_;
+
+ // Identifier of the dead flow state. The empty flow state is represented with
+ // a nullptr.
+ CondId dead_id_;
+};
+
+// FunctionalizeCond groups all the state used by functionalizing conditionals
+// of the given graph together.
+class FunctionalizeCond {
+ public:
+ // Functionalize all the switch-merge nodes of a loop-free graph into If
+ // nodes. That is, attempt to transform every remaining switch and merge nodes
+ // in the graph into If nodes.
+ // Precondition: All while loops have been removed from graph.
+ static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Build identity node with the same name as the merge that will be replaced
+ // in case the output is fetched/colocated.
+ Status AddIdentityNode(const Node* replacee, Node* if_node, int port);
+
+ // Add a If node to the graph defined by def that will, amongst other, replace
+ // replacee in the graph.
+ xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee);
+
+ // Propagates the state of a newly inserted node.
+ Status PropagateUpdatedState(const Node* replacee);
+
+ // Dump graph with the CondState annotated.
+ void DumpGraphWithCondState(const string& name);
+
+ private:
+ FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Performs the actual cond functionalization. Iterate over groups of merge
+ // nodes (linked by common predicate & CondIds of the incomming edges),
+ // from innermost to outermost, and extract into If nodes.
+ Status FunctionalizeInternal();
+
+ // Returns the forward flow state propagated along edge `e`.
+ // This may modify cond_state_map_.
+ CondStateMap::CondId StateAlongEdge(const Edge* e);
+
+ // Determines the CondState of all the nodes in the given vector where
+ // the input is expected in reverse topological order.
+ // This populates the cond_state_map_.
+ Status DetermineCondStates(std::vector<Node*> rev_topo_order);
+
+ // Determine the CondState for a given node using the incomming edges
+ // to the node. Note: it is expected that this node's CondState is only
+ // determined once its input's CondState is.
+ Status DetermineCondState(Node* dst);
+
+ // Helper functions for DetermineCondState.
+ Status DetermineCondStateMerge(Node* dst);
+
+ // Helper functions for DetermineCondStates. Determines the dst node's
+ // CondState by joining the src and dst's CondState where either
+ // the dst node is a merge or not.
+ // These may modify cond_state_map_.
+ xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst);
+ xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst);
+
+ // Checks if a merge node is redundant and if so removes it from the graph.
+ Status RemoveRedundantMerge(Node* node);
+
+ // Checks if a switch node is redundant and if so removes it from the graph.
+ Status RemoveRedundantSwitch(Node* node);
+
+ // Sorts merge nodes (in reverse topological order) in order of increasing
+ // nesting depth.
+ void SortMergeNodes(std::vector<Node*>* merge_order);
+
+ // Deletes all nodes in/consumers of `delete_nodes_`.
+ void DeleteReachableNodes();
+
+ // Member used to unique the CondState to a unique CondId and keep track of
+ // CondState/CondId per Node.
+ CondStateMap cond_state_map_;
+
+ // Nodes to be deleted.
+ std::deque<int> delete_nodes_;
+
+ FunctionLibraryDefinition* library_;
+ Graph* graph_;
+
+ friend class FunctionalizeCondTest;
+};
+
+} // namespace functionalize_cond
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
new file mode 100644
index 0000000000..88a942648f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests for the backward const analysis.
+
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace functionalize_cond {
+
+class FunctionalizeCondTest : public ::testing::Test {
+ protected:
+ FunctionalizeCondTest() {
+ graph_.reset(new Graph(OpRegistry::Global()));
+ flib_def_.reset(
+ new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_));
+ fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(),
+ flib_def_.get()));
+ }
+
+ CondStateMap::CondId GetUniqueId(
+ const CondStateMap::CondStateMap::CondState& state) {
+ return fc_->cond_state_map_.GetUniqueId(state);
+ }
+
+ xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst) {
+ return fc_->JoinCondStatesNonMerge(src, dst);
+ }
+
+ xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
+ CondStateMap::CondId src, CondStateMap::CondId dst) {
+ return fc_->JoinCondStatesMerge(src, dst);
+ }
+
+ bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) {
+ return fc_->cond_state_map_.ScopeIn(ff, scope);
+ }
+
+ CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds(
+ CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
+ return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs);
+ }
+
+ FunctionDefLibrary fdef_lib_;
+ std::unique_ptr<functionalize_cond::FunctionalizeCond> fc_;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ std::unique_ptr<Graph> graph_;
+};
+
+namespace {
+
+// TODO(jpienaar): Re-enable. Disabling for ASAN failure.
+TEST_F(FunctionalizeCondTest, DISABLED_ScopeIn) {
+ Tensor pred_tensor(DT_BOOL, TensorShape());
+ Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred");
+ Tensor val_tensor(DT_INT32, TensorShape());
+ Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
+ Node* s = test::graph::Switch(graph_.get(), val, pred);
+
+ {
+ CondStateMap::CondStateMap::CondState ss;
+ ss.emplace_back(CondStateMap::CondNode(
+ CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
+ CondStateMap::CondId id = GetUniqueId(ss);
+ CondStateMap::CondId scope;
+ ASSERT_TRUE(ScopeIn(id, &scope));
+ ASSERT_TRUE(id == scope);
+ }
+
+ CondStateMap::CondState empty;
+ {
+ CondStateMap::CondState ss;
+ ss.emplace_back(CondStateMap::CondNode(
+ CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
+ ss.emplace_back(
+ CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
+ CondStateMap::CondId id = GetUniqueId(ss);
+ CondStateMap::CondId scope_1;
+ ASSERT_TRUE(ScopeIn(id, &scope_1));
+ ASSERT_TRUE(scope_1 == GetUniqueId(empty));
+ ASSERT_TRUE(id != scope_1);
+
+ ss.clear();
+ ss.emplace_back(CondStateMap::CondNode(
+ CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
+ id = GetUniqueId(ss);
+ CondStateMap::CondId scope_2;
+ ASSERT_TRUE(ScopeIn(id, &scope_2));
+
+ ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) ==
+ CondStateMap::ContainsResult::kLhsContainsRhs);
+ }
+}
+
+// TODO(jpienaar): Re-enable. Disabling for ASAN failure.
+TEST_F(FunctionalizeCondTest, DISABLED_JoinCondStates) {
+ Tensor pred_tensor(DT_BOOL, TensorShape());
+ Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred");
+ Tensor val_tensor(DT_INT32, TensorShape());
+ Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
+ Node* s = test::graph::Switch(graph_.get(), val, pred);
+
+ CondStateMap::CondId empty = GetUniqueId({});
+
+ CondStateMap::CondId then_branch;
+ {
+ CondStateMap::CondState ss;
+ ss.emplace_back(CondStateMap::CondNode(
+ CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
+ then_branch = GetUniqueId(ss);
+ }
+ CondStateMap::CondId else_branch;
+ {
+ CondStateMap::CondState ss;
+ ss.emplace_back(CondStateMap::CondNode(
+ CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch));
+ else_branch = GetUniqueId(ss);
+ }
+
+ // An non-merge op with inputs from then and else branch.
+ Status status = JoinCondStatesNonMerge(then_branch, else_branch).status();
+ EXPECT_TRUE(errors::IsInvalidArgument(status));
+
+ // Merge between then and else branch.
+ auto joined_or = JoinCondStatesMerge(then_branch, else_branch);
+ TF_EXPECT_OK(joined_or.status());
+ CondStateMap::CondId joined = joined_or.ValueOrDie();
+
+ // Merge between then branch and both branch.
+ auto t = JoinCondStatesNonMerge(then_branch, joined);
+ // Note: this is OK in terms of constraint predication, but
+ TF_EXPECT_OK(t.status());
+
+ // Post merge the propagated forward flow state has an additional merge.
+ CondStateMap::CondId post_merge;
+ {
+ CondStateMap::CondState ss;
+ ss = *joined;
+ ss.emplace_back(
+ CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
+ post_merge = GetUniqueId(ss);
+ }
+
+ t = JoinCondStatesNonMerge(post_merge, joined);
+ TF_EXPECT_OK(t.status());
+ EXPECT_TRUE(joined == t.ValueOrDie());
+
+ // No predicate that results in two paths predicated on different conditions
+ // merge.
+ t = JoinCondStatesMerge(post_merge, joined);
+ EXPECT_FALSE(t.ok());
+
+ // Post the merge we are effectively in the root scope and merging should
+ // result in the more restrictive post merge state.
+ t = JoinCondStatesNonMerge(post_merge, empty);
+ TF_EXPECT_OK(t.status());
+ EXPECT_TRUE(post_merge == t.ValueOrDie());
+}
+
+} // namespace
+} // namespace functionalize_cond
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 0904778f97..188ada7255 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -21,1440 +21,24 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace tensorflow {
-namespace {
-
-using xla::StatusOr;
-
-const char* const kArgOp = "_Arg";
-const char* const kRetValOp = "_Retval";
-
-// Information about a loop argument.
-struct Arg {
- // Every loop argument has an Enter node.
- Node* enter;
-
- // Is the loop argument a loop-invariant value? Taken from the `is_constant`
- // attribute on the Enter node.
- bool is_loop_invariant;
-
- // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
- // arguments must have all of the following nodes:
- Node* merge = nullptr;
- Node* switch_node = nullptr;
- Node* next_iteration = nullptr;
- Node* exit = nullptr;
-};
-
-// Information about a loop frame.
-struct Frame {
- string name;
-
- // Pointer to the parent frame. The root frame has a pointer to itself.
- Frame* parent = nullptr;
- int num_children = 0;
-
- // Arguments to this loop.
- std::vector<Arg> args;
-
- // The loop condition of the loop. There should be exactly one loop condition
- // in every loop.
- Node* loop_cond = nullptr;
-
- // Set of nodes that belong to the loop frame.
- std::unordered_set<Node*> nodes;
-};
-
-// Comparison function used for sorting nodes consistently.
-// a) resource variables are last, and
-// b) sort lexicographically by name (for deterministic output).
-struct NodeCmp {
- bool operator()(const Node* lhs, const Node* rhs) const {
- bool lhs_is_resource =
- lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
- bool rhs_is_resource =
- rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
- return std::tie(lhs_is_resource, lhs->name()) <
- std::tie(rhs_is_resource, rhs->name());
- }
-};
-
-// Returns a textual representation of the names of the nodes in the input.
-template <typename T>
-string NodesToString(const T& nodes) {
- return strings::StrCat("{",
- str_util::Join(nodes, ",",
- [](string* output, const Node* node) {
- strings::StrAppend(output,
- node->name());
- }),
- "}");
-}
-
-// Copies a subgraph from `graph` to `output` by performing a reverse DFS
-// starting at nodes in vector `stack`.
-// `node_map` is a vector indexed by source node ID to dest nodes.
-// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
-// before the traversal clients can cut the graph. If a frame is provided (frame
-// != nullptr), then this functions will return an error if the
-// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
-// cut the graph and prevent the traversal from escaping.
-//
-// `squash_src_outputs` contains a bool for each source node ID. If true, then
-// the source output on that node will be replaced by zero when copied. This is
-// used when replacing a Switch node with an _Arg node. The output we are
-// taking from the Switch node was not necessarily the first output, but _Arg
-// nodes only have one output. By adding the Switch node to `squash_src_outputs`
-// we rewrite the src_output of the corresponding edge to be 0.
-Status CopySubgraph(const Graph& graph, const Frame* frame,
- std::vector<Node*> stack,
- const std::vector<bool>& squash_src_outputs,
- std::vector<Node*>* node_map, Graph* output) {
- VLOG(3) << "Stack: " << NodesToString(stack);
- std::vector<bool> visited(graph.num_node_ids(), false);
- while (!stack.empty()) {
- Node* n = stack.back();
- stack.pop_back();
-
- VLOG(5) << "Copying node " << n->name();
-
- if (visited[n->id()]) continue;
- visited[n->id()] = true;
-
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
- // We traversed out of the loop frame, without encountering a cut node.
- return errors::Internal("Graph traversal of loop frame ", frame->name,
- " escaped frame at ", src->name(),
- " without encountering an argument node.");
- }
- if ((*node_map)[src->id()] == nullptr) {
- (*node_map)[src->id()] = output->CopyNode(src);
- stack.push_back(src);
- }
- Node* src_copy = (*node_map)[e->src()->id()];
- int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
- ? 0
- : e->src_output();
- Node* dst_copy = (*node_map)[e->dst()->id()];
- output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
- Status status;
- Node* inserted_node = graph->AddNode(node_def, &status);
- if (!status.ok()) {
- return status;
- }
- return inserted_node;
-}
-
-// Check that the graph has no cycle containing the given node.
-Status CheckNoCycleContains(const Node* node, const int num_nodes) {
- std::vector<const Node*> ready;
- ready.push_back(node);
- std::vector<bool> visited(num_nodes);
- while (!ready.empty()) {
- const Node* current_node = ready.back();
- ready.pop_back();
- visited[current_node->id()] = true;
- for (const Edge* out : current_node->out_edges()) {
- if (out->dst() == node) {
- return errors::Internal("Detected a cycle: ", FormatNodeForError(*node),
- "(", node->def().op(), ") feeds into itself.");
- } else if (!visited[out->dst()->id()]) {
- ready.push_back(out->dst());
- }
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
- NodeDef arg_def;
- NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
- builder.Attr("T", type);
- builder.Attr("index", index);
- TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
- return AddNode(arg_def, graph);
-}
-
-StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
- NodeDef ret_def;
- ret_def.set_op(kRetValOp);
- ret_def.set_name(strings::StrCat(kRetValOp, index));
- AddNodeAttr("T", type, &ret_def);
- AddNodeAttr("index", index, &ret_def);
- return AddNode(ret_def, graph);
-}
-
-// Builds a graph for the loop condition.
-Status BuildLoopCondition(const Graph& graph, Frame* frame,
- std::unique_ptr<Graph>* cond_output) {
- VLOG(2) << "Building loop condition for " << frame->name;
- *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
- Graph* output = cond_output->get();
-
- // Map from nodes in the original graph to the condition graph.
- std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
- std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
-
- // Build one _Arg node for each Enter node.
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
-
- TF_ASSIGN_OR_RETURN(Node * arg_node,
- BuildArgNode(output, arg.enter->input_type(0), i));
- if (arg.is_loop_invariant) {
- node_map[arg.enter->id()] = arg_node;
- } else {
- node_map[arg.merge->id()] = arg_node;
- }
- }
-
- // Build a Retval node for the loop condition. The LoopCond nodes are always
- // boolean because of the type constraints on the LoopCond op.
- TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
- BuildRetvalNode(output, DT_BOOL, 0));
-
- // Performs a reverse DFS, copying nodes and edges to the output graph.
- // The _Arg and _Retval nodes were added unconditionally above, so we are
- // guaranteed to get the correct function signature.
- return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
- &node_map, output);
-}
-
-// Builds a graph for the loop body.
-Status BuildLoopBody(const Graph& graph, Frame* frame,
- DataTypeVector* arg_types,
- std::unique_ptr<Graph>* body_output) {
- VLOG(2) << "Building loop body for " << frame->name;
- *body_output = xla::MakeUnique<Graph>(graph.op_registry());
- Graph* output = body_output->get();
-
- // Map from nodes in the original graph to the condition graph.
- std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
- std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
-
- // Build one _Arg node for each Enter node.
- std::vector<Node*> next_iterations;
- next_iterations.reserve(frame->args.size());
- arg_types->reserve(frame->args.size());
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
-
- DataType dtype = arg.enter->input_type(0);
- arg_types->push_back(dtype);
-
- TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
-
- if (dtype == DT_RESOURCE) {
- // The convention of the XLA bridge is that resource variable arguments
- // are only inputs to the loop body and have no corresponding output.
- // TODO(b/37741920): change the convention so that DT_RESOURCE variables
- // are both inputs and outputs, and then remove this case.
- TF_RET_CHECK(arg.is_loop_invariant);
- node_map[arg.enter->id()] = arg_node;
- } else {
- TF_ASSIGN_OR_RETURN(Node * retval_node,
- BuildRetvalNode(output, dtype, i));
-
- if (arg.is_loop_invariant) {
- // Argument is loop-invariant. Forward it from the Arg to the Retval.
- node_map[arg.enter->id()] = arg_node;
- output->AddEdge(arg_node, 0, retval_node, 0);
- } else {
- // Argument is loop-varying.
- node_map[arg.switch_node->id()] = arg_node;
- // The Switch node has two outputs, but _Arg only has one. This tells
- // the CopySubgraph function to rewrite the output number of edges from
- // the _Arg node to be 0 rather than copying the output number from the
- // Switch node.
- squash_src_outputs[arg.switch_node->id()] = true;
- node_map[arg.next_iteration->id()] = retval_node;
- next_iterations.push_back(arg.next_iteration);
- }
- }
- }
-
- // Performs a reverse DFS, copying nodes and edges to the output graph.
- // The _Arg and _Retval nodes were added unconditionally above, so we are
- // guaranteed to get the correct function signature.
- TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
- squash_src_outputs, &node_map, output));
-
- return Status::OK();
-}
-
-// Copy the FunctionDef of given function from lookup_library to library, if
-// it can be found in lookup_library but is missing from library.
-Status AddMissingFunctionByName(const string& function_name,
- const FunctionLibraryDefinition* lookup_library,
- FunctionLibraryDefinition* library) {
- if (!library->Find(function_name) && lookup_library->Find(function_name)) {
- return library->AddFunctionDef(*lookup_library->Find(function_name));
- }
- return Status::OK();
-}
-
-// Iterate over all functions that the given fdef refers to. Copy the missing
-// FunctionDefs from lookup_library to library.
-Status AddMissingFunctionDef(const FunctionDef& fdef,
- const FunctionLibraryDefinition* lookup_library,
- FunctionLibraryDefinition* library) {
- TF_RET_CHECK(lookup_library);
- for (const NodeDef& node : fdef.node_def()) {
- if (library->Find(node.op())) {
- continue;
- }
- // The function referred by 'SymbolicGradient' node is specified in its
- // attribute 'f'.
- if (node.op() == FunctionLibraryDefinition::kGradientOp) {
- const AttrValue* attr =
- AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
- if (!attr) {
- return errors::InvalidArgument("SymbolicGradient is missing attr: f");
- }
- const string& func_name = attr->func().name();
- TF_RETURN_IF_ERROR(
- AddMissingFunctionByName(func_name, lookup_library, library));
- // Copy the user-defined gradient function if it exists.
- const string grad_name = lookup_library->FindGradient(func_name);
- if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
- TF_RETURN_IF_ERROR(
- AddMissingFunctionByName(grad_name, lookup_library, library));
- GradientDef grad_def;
- grad_def.set_function_name(func_name);
- grad_def.set_gradient_func(grad_name);
- TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
- }
- } else if (lookup_library->Find(node.op())) {
- TF_RETURN_IF_ERROR(
- library->AddFunctionDef(*lookup_library->Find(node.op())));
- }
- }
- return Status::OK();
-}
-
-Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
- Graph* graph, Frame* frame,
- FunctionLibraryDefinition* library) {
- VLOG(2) << "Frame " << frame->name << " before: "
- << dump_graph::DumpGraphToFile("functionalize_before", *graph,
- library);
-
- // Split loop-varying Enter nodes with multiple successors. If the same
- // Tensor is fed as input to multiple loop arguments, we may end up with a
- // shared Enter node. We clone Enter nodes with multiple successors to
- // maintain the invariant of a unique Enter node per argument of the final
- // loop.
- std::vector<Arg> args;
- for (const Arg& arg : frame->args) {
- if (arg.is_loop_invariant) {
- args.push_back(arg);
- } else {
- std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
- arg.enter->out_edges().end());
- for (int i = 0; i < edges.size(); ++i) {
- if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
- continue;
- }
- TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
- Arg new_arg;
- new_arg.is_loop_invariant = false;
- if (i == 0) {
- new_arg.enter = arg.enter;
- } else {
- new_arg.enter = graph->CopyNode(arg.enter);
- frame->nodes.insert(new_arg.enter);
- for (Edge const* e : arg.enter->in_edges()) {
- graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
- e->IsControlEdge() ? Graph::kControlSlot : 0);
- }
- Node* dst = edges[i]->dst();
- int dst_input = edges[i]->dst_input();
- graph->RemoveEdge(edges[i]);
- graph->AddEdge(new_arg.enter, 0, dst, dst_input);
- }
- args.push_back(new_arg);
- }
- }
- }
- frame->args = std::move(args);
-
- std::sort(
- frame->args.begin(), frame->args.end(),
- [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
-
- if (frame->loop_cond == nullptr) {
- return errors::InvalidArgument("Loop ", frame->name,
- " has no LoopCond node");
- }
-
- // Find the set of Switch nodes that are successors of the LoopCond.
- std::unordered_set<Node*> switches;
- for (const Edge* edge : frame->loop_cond->out_edges()) {
- if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
- edge->dst_input() == 1) {
- switches.insert(edge->dst());
- }
- }
-
- // For each non-constant argument, looks for the following pattern of nodes:
- // Enter ----> Merge --------> Switch --> Exit
- // ^ ^
- // | |
- // NextIteration LoopCond
- // ^ ^
- // | |
- // ... ...
- for (Arg& arg : frame->args) {
- if (!arg.is_loop_invariant) {
- // Follow the edge from the Enter to Merge.
- const Edge* enter_merge = nullptr;
- for (const Edge* e : arg.enter->out_edges()) {
- // Ignore control-edges to the sink node. These are allowed by the
- // graph invariants, although probably they should have been stripped
- // off earlier.
- if (e->IsControlEdge() && e->dst()->IsSink()) {
- continue;
- }
- if (enter_merge != nullptr) {
- return errors::Internal("Enter node for loop-varying argument ",
- FormatNodeForError(*arg.enter),
- " has multiple successors: ",
- FormatNodeForError(*enter_merge->dst()),
- " and ", FormatNodeForError(*e->dst()));
- }
- enter_merge = e;
- }
- if (enter_merge == nullptr) {
- return errors::Internal("Enter node for loop-varying argument ",
- FormatNodeForError(*arg.enter),
- " has zero successors");
- }
- arg.merge = enter_merge->dst();
- if (!IsMerge(arg.merge)) {
- return errors::InvalidArgument(
- "Successor of Enter node for loop-varying argument ",
- FormatNodeForError(*arg.merge),
- " is not a Merge node; got: ", arg.merge->type_string());
- }
-
- // Find the NextIteration from the merge. There should be two inputs to
- // the Merge and the NextIteration should be the other input.
- if (arg.merge->input_types().size() != 2) {
- return errors::InvalidArgument(
- "Unexpected number of inputs to Merge node for loop-varying "
- "argument ",
- FormatNodeForError(*arg.merge), "; expected 2, got ",
- arg.merge->input_types().size());
- }
- TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
- &arg.next_iteration));
- if (!IsNextIteration(arg.next_iteration)) {
- return errors::InvalidArgument(
- "Expected NextIteration node as input to Merge node; got node ",
- FormatNodeForError(*arg.next_iteration), " with kind ",
- arg.next_iteration->type_string());
- }
-
- // Find the Switch successor of the Merge. There should be exactly one
- // Switch node that is a successor of both the Merge and the LoopCond.
- for (const Edge* edge : arg.merge->out_edges()) {
- if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
- switches.find(edge->dst()) != switches.end()) {
- if (arg.switch_node != nullptr) {
- return errors::InvalidArgument("Duplicate Switch successors to ",
- FormatNodeForError(*arg.merge));
- }
- arg.switch_node = edge->dst();
- }
- }
- if (arg.switch_node == nullptr) {
- return errors::InvalidArgument("Missing Switch successor to ",
- FormatNodeForError(*arg.merge));
- }
-
- // Update the device on the Identity outputs of the switch to match their
- // target. These Identity outputs do not
-
- // Loop over the switch node's output to:
- // - Find the Exit successor.
- // - Set the sharding on all Identity outputs of the switch. These
- // identity nodes are values used by the loop body or condition.
- // The Identity node may have the wrong device so copy the device from
- // one of its outputs instead.
- std::deque<const Edge*> possible_exit;
- for (const Edge* edge : arg.switch_node->out_edges()) {
- if (edge->src_output() == 0) {
- possible_exit.push_back(edge);
- }
- if (IsIdentity(edge->dst())) {
- TF_RETURN_IF_ERROR(
- SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
- }
- }
- // TODO(b/67425339): Allow general graph between switch and exit.
- while (!possible_exit.empty()) {
- const Edge* edge = possible_exit.front();
- possible_exit.pop_front();
- if (IsExit(edge->dst())) {
- if (arg.exit != nullptr) {
- return errors::InvalidArgument(
- "Duplicate Exit successors to ",
- FormatNodeForError(*arg.switch_node));
- }
- arg.exit = edge->dst();
- } else {
- if (!IsIdentity(edge->dst())) {
- return errors::Unimplemented("General graph between switch (",
- FormatNodeForError(*arg.switch_node),
- ") and exit node of frame ",
- frame->name, " not supported yet.");
- }
- for (const Edge* out : edge->dst()->out_edges()) {
- possible_exit.push_back(out);
- }
- }
- }
- }
- }
-
- // Builds the condition and body functions.
- std::unique_ptr<Graph> cond_graph;
- TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
- DataTypeVector arg_types;
- std::unique_ptr<Graph> body_graph;
- TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
-
- VLOG(2) << "Frame " << frame->name << " condition: "
- << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
- << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
-
- static std::atomic<int64> sequence_num(0LL);
- int64 id = ++sequence_num;
- NameAttrList cond_name;
- cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
- NameAttrList body_name;
- body_name.set_name(strings::StrCat("_functionalize_body_", id));
- FunctionDef cond_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
- FunctionDef body_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
-
- TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
- TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
- if (lookup_library) {
- // Copy missing FunctionDefs from lookup_library to library to make library
- // self-contained.
- TF_RETURN_IF_ERROR(
- AddMissingFunctionDef(cond_fdef, lookup_library, library));
- TF_RETURN_IF_ERROR(
- AddMissingFunctionDef(body_fdef, lookup_library, library));
- }
-
- // Builds a While operator.
- NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
- builder.Attr("T", arg_types);
- builder.Attr("cond", cond_name);
- builder.Attr("body", body_name);
- std::vector<NodeDefBuilder::NodeOut> inputs;
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- builder.ControlInput(in_edge->src()->name());
- } else {
- inputs.push_back(NodeDefBuilder::NodeOut(
- in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
- }
- }
- builder.Input(inputs);
- TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
- TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph));
-
- // Copies edges to the Enter nodes and from the Exit nodes onto the While.
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- graph->AddControlEdge(in_edge->src(), while_node);
- } else {
- graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
- }
-
- if (!arg.is_loop_invariant) {
- // Add output edges if the output of the loop is consumed.
- if (arg.exit != nullptr) {
- std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
- arg.exit->out_edges().end());
- for (const Edge* edge : edges) {
- Node* dst = edge->dst();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (dst_input == Graph::kControlSlot) {
- graph->AddControlEdge(while_node, dst);
- } else {
- graph->AddEdge(while_node, i, dst, dst_input);
- }
- }
- }
- }
- }
-
- // Remove the old nodes from the graph, and add the while node to the parent
- // frame.
- for (Node* node : frame->nodes) {
- graph->RemoveNode(node);
- }
- frame->nodes.clear();
- frame->parent->nodes.insert(while_node);
-
- VLOG(2) << "Frame " << frame->name << " after: "
- << dump_graph::DumpGraphToFile("functionalize_after", *graph,
- library);
-
- return Status::OK();
-}
-
-class FunctionalizeCond {
- public:
- // All nodes are assumed to be either in no branch, then branch, else branch,
- // or both branches (such as merge nodes).
- enum Branch {
- kElseBranch = 0,
- kThenBranch = 1,
- kBoth = 2,
- kNeither = 3,
- kNumBranchTypes = 4
- };
-
- // Returns a textual representation of the Branch b.
- static string Branch_Name(FunctionalizeCond::Branch b);
-
- // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
- // nodes. That is, attempt to transform every remaining switch and merge nodes
- // in the graph into XlaIf nodes.
- // Precondition: All while loops have been removed from graph.
- static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
-
- private:
- // CondArgNode represents a input to the conditional and its corresponding
- // switch nodes.
- struct CondArgNode {
- explicit CondArgNode(Node* src, int src_output)
- : src(src), src_output(src_output) {}
- string ToString() const {
- return strings::StrCat("src=", src->name(), ":", src_output,
- " switches=", NodesToString(switches));
- }
-
- Node* src;
- int src_output;
- std::vector<Node*> switches;
- };
- using CondArgNodes = std::vector<CondArgNode>;
-
- struct ForwardFlowNode {
- explicit ForwardFlowNode(Branch branch = Branch::kNeither)
- : branch(branch), count(0) {}
- string ToString() const {
- return strings::StrCat("branch=", Branch_Name(branch), " count=", count);
- }
- Branch branch;
- int count;
- };
-
- // Group of switch nodes that will be part of the same XlaIf.
- struct SwitchCluster {
- explicit SwitchCluster(const Edge* predicate_edge)
- : predicate_edge(predicate_edge) {}
- string ToString() const {
- return strings::StrCat(name, " predicate=", predicate_edge->src()->name(),
- " switches=", NodesToString(switches));
- }
-
- string name;
- const Edge* predicate_edge;
- std::vector<Node*> switches;
- };
-
- FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
- bool dump_graphs)
- : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
-
- // Perform the actual cond functionalization. Iterate over groups of switch
- // nodes (linked by common predicate), from innermost to outermost, and
- // extract into XlaIf nodes.
- Status FunctionalizeInternal();
-
- // Determines the branch_map (mapping from node to branch of cond) and
- // frontier (the nodes where the cond ends).
- StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
- std::unordered_set<Node*>>>
- DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
-
- // Returns XlaIf node created from subgraph of merge and switch nodes. This
- // encapsulates the process of extracting the bodies needed for the then and
- // else branch, creates a XlaIf node, removing the nodes of the branches from
- // the graph and replacing the merge node with a XlaIf.
- StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
- const SwitchCluster& switch_cluster,
- const std::vector<Node*>& switches);
-
- // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
- StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
- const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes);
-
- // Extracts a function body corresponding to the given input edge of the merge
- // node.
- Status ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switches,
- const std::vector<Node*>& merge_nodes, int input_edge,
- Graph* body);
-
- // Adds all the input edges to `if_node` corresponding to the arguments.
- Status AddInputEdges(const CondArgNodes& cond_arg_nodes,
- const Edge* predicate_edge, Node* if_node);
-
- // Adds all output edges from the `if_node`.
- Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
-
- // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
- // skipped and removed from the graph.
- StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
-
- // Update the state for destination based on the state of source and the node
- // being updated.
- Status Join(const ForwardFlowNode& src_state, const Node* dst,
- ForwardFlowNode* dst_state);
-
- // Ensure that all nodes in the branch_map are dominated by the switch
- // nodes. Returns nodes that are not dominated by the switches but are a
- // control dependency of a node in the cond, and remove such control
- // dependencies.
- StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::vector<Node*>& switches);
-
- // Validates that the frontier of nodes for the conditional
- // section are as expected.
- Status ValidateFrontier(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::unordered_set<Node*>& frontier);
-
- FunctionLibraryDefinition* library_;
- Graph* graph_;
- bool dump_graphs_;
-};
-
-bool IsDeadSwitch(const Node* node) {
- for (const Edge* e : node->out_edges()) {
- const Node* dst = e->dst();
- if (!dst->IsIdentity()) {
- return false;
- }
- for (const Edge* ee : dst->out_edges()) {
- if (!ee->IsControlEdge() || !ee->dst()->IsSink()) {
- return false;
- }
- }
- }
- return true;
-}
-
-string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) {
- const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = {
- "else", "then", "both", "neither", "count"};
- return branch_name[b];
-}
-
-Status FunctionalizeCond::ValidateFrontier(
- const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
- branch_map,
- const std::unordered_set<Node*>& frontier) {
- std::unordered_set<const Node*> pending[kNumBranchTypes];
- for (Node* n : frontier) {
- pending[branch_map.at(n).branch].insert(n);
- }
- TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]);
- for (const Node* n : pending[kBoth]) {
- TF_RET_CHECK(IsMerge(n)) << n->DebugString();
- // Merge nodes may be in then or else branch too
- }
- int index = (pending[kThenBranch].size() <= pending[kElseBranch].size())
- ? kThenBranch
- : kElseBranch;
- int other = 1 - index;
- for (const Node* n : pending[index]) {
- if (pending[other].find(n) != pending[other].end()) {
- return errors::Internal(
- "Node (", n->DebugString().c_str(),
- ") in both Else and Then branch should be in Both.");
- }
- }
- // An empty frontier indicates a dead switch. Above we attempt to remove dead
- // switch nodes, but not all are removed so don't treat it as an error yet.
- // TODO(jpienaar): Find out why dead switch nodes remain.
- // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
- // pending[kElseBranch].empty()) {
- // return errors::Internal("Unexpected empty frontier for switch nodes");
- // }
- return Status::OK();
-}
-
-Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
- const Node* dst, ForwardFlowNode* dst_state) {
- TF_RET_CHECK(dst_state->branch != Branch::kBoth &&
- dst_state->branch != Branch::kNumBranchTypes)
- << "Unexpected/Invalid branch type: Merging "
- << Branch_Name(src_state.branch) << " with "
- << Branch_Name(dst_state->branch);
- if (dst_state->branch == Branch::kNeither) {
- dst_state->branch = src_state.branch;
- } else if (src_state.branch != dst_state->branch &&
- src_state.branch != Branch::kNeither) {
- if (IsMerge(dst)) {
- dst_state->branch = Branch::kBoth;
- } else {
- return errors::Internal("Illegal merge:\n", src_state.ToString(),
- " with ", dst_state->ToString(), " for\n",
- dst->DebugString());
- }
- }
- ++dst_state->count;
- return Status::OK();
-}
-
-StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
-FunctionalizeCond::DeterminePredicateSwitchOrder() {
- struct Cluster {
- bool operator==(const Cluster& other) const {
- return representative == other.representative;
- }
- int representative = -1;
- };
-
- // Perform a DFS over the graph and
- // * Determine the reverse topological order of the nodes (there should be no
- // cycles at this point so the post-order numbering corresponds to the
- // reverse topological sorting);
- // * Identify dead switches;
- // * Initialize the cluster's representative;
- std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
- std::vector<Node*> dead_switches;
- std::vector<Node*> switch_order;
- std::vector<Node*> rev_topo_sorted_nodes;
- DFS(*graph_, nullptr, [&](Node* n) {
- clusters[n->id()].Get().representative = n->id();
- if (IsSwitch(n)) {
- if (IsDeadSwitch(n)) {
- dead_switches.push_back(n);
- } else {
- rev_topo_sorted_nodes.push_back(n);
- switch_order.push_back(n);
- }
- } else if (n->IsOp()) {
- // Exclude src and sink nodes from further consideration.
- rev_topo_sorted_nodes.push_back(n);
- }
- });
-
- std::vector<SwitchCluster> switch_clusters;
- // Return early if there are no switches in the graph.
- if (switch_order.empty()) {
- return switch_clusters;
- }
-
- // Remove all dead switch nodes.
- for (Node* n : dead_switches) {
- VLOG(2) << "Removing dead switch: " << n->DebugString();
- graph_->RemoveNode(n);
- }
-
- // Identify switch nodes that are part of the same control flow context by
- // considering the operands of operations: an operation is part of the same
- // control context as its operands unless the operation is a switch. Control
- // dependencies are considered part of the same control flow context if the
- // switch depth is the same (see comment below).
-
- // entry_cluster records the input cluster to a switch node. This is used when
- // merging with a merge node where the dst's cluster is merged with the entry
- // cluster of the merge node's cluster (which corresponds to a switch cluster
- // and so has an entry cluster).
- std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
-
- // Returns the output cluster of a node. Where the output cluster is cluster
- // where the output of the node is used. For non-merge nodes this is simply
- // the cluster they are part of, while for merge nodes it is the entry cluster
- // of the cluster they are part of (this will correspond to the entry node of
- // a switch node that dominates the merge).
- auto find_output_cluster = [&](Node* n) {
- UnionFind<Cluster>* cluster = &clusters[n->id()];
- if (!IsMerge(n)) return cluster;
- auto it = entry_cluster.find(clusters[n->id()].Get().representative);
- // If the cluster is not found in the entry_cluster map then an
- // instruction not dominated by a switch node has been merged into the
- // cluster of the merge. This indicates a failure of the clustering.
- CHECK(it != entry_cluster.end())
- << "Unable to find entry for n=" << n->id() << " ("
- << cluster->Get().representative << ")";
- return it->second;
- };
-
- // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
- std::vector<int> switch_depth(graph_->num_node_ids());
- for (auto it = rev_topo_sorted_nodes.rbegin();
- it != rev_topo_sorted_nodes.rend(); ++it) {
- Node* n = *it;
-
- // Compute switch depth.
- int new_switch_depth = 0;
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- new_switch_depth = std::max(
- new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
- }
- switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
-
- // Only merge the input operands of a switch. The switch's clustering itself
- // is determined by the interaction of the switch's outputs.
- if (IsSwitch(n)) {
- Node* input;
- TF_CHECK_OK(n->input_node(0, &input));
- entry_cluster[n->id()] = find_output_cluster(input);
- UnionFind<Cluster>* cluster = entry_cluster[n->id()];
- int cluster_depth = switch_depth[cluster->Get().representative];
- // Merge the inputs of the switch node with one another. This results in
- // predicates and control input residing in the same cluster.
- for (const Edge* e : n->in_edges()) {
- // Only consider the data inputs to the Switch node.
- if (e->IsControlEdge()) continue;
-
- Node* src = e->src();
- UnionFind<Cluster>* src_cluster = find_output_cluster(src);
- int src_cluster_depth = switch_depth[src_cluster->Get().representative];
- if (cluster_depth != src_cluster_depth) {
- return errors::InvalidArgument(
- "Unable to functionalize control flow in graph: Switch ('",
- n->name(), "') has operands ('", input->name(), "' and '",
- src->name(), "') that have different switch depths (",
- cluster_depth, " != ", src_cluster_depth, ")");
- }
- cluster->Merge(src_cluster);
- }
- continue;
- }
-
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- if (!src->IsOp()) continue;
- UnionFind<Cluster>* cluster = find_output_cluster(src);
- // Merge a node with its data operands and with its control operands if
- // the src and dst are in the same ControlContext. The ControlContext is
- // not explicitly available here, and instead the switch depth is used as
- // a proxy here. Due to the invariant that control edges can only be from
- // a containing scope to an inner scope or from the inner scope to its
- // containing scope (for exit nodes), the switch depth will only match if
- // the src and dst are in the same ControlContext. Control edges between
- // ControlContexts are handled during the extraction.
- int src_id = cluster->Get().representative;
- int src_depth = switch_depth[src_id];
- if (!e->IsControlEdge() || new_switch_depth == src_depth) {
- if (src_depth != new_switch_depth) {
- // TODO(b/77601805) remove this when outside_compilation supports
- // control flow.
- if (str_util::StrContains(src->name(), "outside_compilation") ||
- str_util::StrContains(n->name(), "outside_compilation")) {
- return errors::InvalidArgument(
- "outside_compilation is not yet supported within TensorFlow "
- "control flow constructs b/77601805");
- }
- return errors::InvalidArgument(
- "Unable to functionalize control flow in graph: Operand ('",
- src->name(), "') and operator ('", n->name(),
- "') have different switch depths (", src_depth,
- " != ", new_switch_depth, ")");
- }
- cluster->Merge(&clusters[n->id()]);
- }
- }
- }
-
- if (dump_graphs_) {
- // Mark the switch cluster each node is part of.
- for (Node* n : graph_->nodes()) {
- n->ClearAttr("_XlaFunctionalizeSwitchGroup");
- n->AddAttr("_XlaFunctionalizeSwitchGroup",
- clusters[n->id()].Get().representative);
- }
- LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
- << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
- library_);
- }
-
- // Verify all the nodes of a cluster are at the same depth.
- std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
- for (Node* n : graph_->nodes()) {
- int depth = switch_depth[n->id()];
- int cluster_rep = clusters[n->id()].Get().representative;
- auto it = cluster_to_depth_node.find(cluster_rep);
- if (it == cluster_to_depth_node.end()) {
- cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
- } else {
- if (it->second.first != depth) {
- return errors::Internal(
- "Illegal clustering created, mismatch in depths:", "\n\t",
- n->DebugString(), "(", clusters[n->id()].Get().representative,
- ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
- "(", clusters[n->id()].Get().representative, ") at depth ",
- it->second.first);
- }
- }
- }
-
- struct Hash {
- size_t operator()(const std::pair<Node*, Cluster>& item) const {
- return Hash64Combine(hash<Node*>()(item.first),
- std::hash<int>()(item.second.representative));
- }
- };
-
- // Merge Switch nodes with common predicate.
- std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
- // The nodes in switch_order are in reverse topological order, but the
- // clustered switches need not be (i.e., when considered as a cluster one
- // element of a cluster may be later in the topological order than another
- // node whose cluster is later in the topological order of clustered
- // switches).
- for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
- const Edge* pred_edge;
- TF_CHECK_OK((*it)->input_edge(1, &pred_edge));
- // The predicate can be preceded by a identity node. Look through identity
- // nodes to predicate.
- while (pred_edge->src()->IsIdentity()) {
- TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge));
- }
- auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get());
- if (predicate_index.find(repr) == predicate_index.end()) {
- predicate_index[repr] = switch_clusters.size();
- switch_clusters.emplace_back(pred_edge);
- // Generate a name by concatenating with the cluster representative as
- // there could be multiple switch clusters with the same predicate.
- switch_clusters[predicate_index[repr]].name = strings::StrCat(
- pred_edge->src()->name(), "_", repr.second.representative, "_If");
- }
- switch_clusters[predicate_index[repr]].switches.push_back(*it);
- }
-
- return switch_clusters;
-}
-
-StatusOr<std::vector<Node*>>
-FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::vector<Node*>& switches) {
- std::vector<Node*> old_control_nodes;
- for (const auto& kv : branch_map) {
- if (kv.second.count != kv.first->in_edges().size()) {
- std::vector<const Edge*> delete_edges;
- for (const Edge* in : kv.first->in_edges()) {
- auto it = branch_map.find(in->src());
- if (it == branch_map.end()) {
- if (in->IsControlEdge()) {
- old_control_nodes.push_back(in->src());
- delete_edges.push_back(in);
- } else {
- if (IsSwitch(in->src())) {
- if (std::find(switches.begin(), switches.end(), in->src()) ==
- switches.end()) {
- return errors::Internal(
- "Unexpected switch node found during flow forward: ",
- in->src()->DebugString());
- }
- continue;
- }
- return errors::InvalidArgument(
- "Value ", kv.first->name(), "'s input, ", in->src()->name(),
- ", is not dominated by switch nodes ", NodesToString(switches));
- }
- }
- }
- // Remove control edges from nodes that are not dominated by the switch
- // nodes. New control dependencies will be added between these nodes and
- // the XlaIf node inserted.
- for (const Edge* e : delete_edges) {
- graph_->RemoveEdge(e);
- }
- }
- }
- return old_control_nodes;
-}
-
-StatusOr<
- std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
- std::unordered_set<Node*>>>
-FunctionalizeCond::DetermineBranchMapAndFrontier(
- const SwitchCluster& switch_cluster) {
- std::unordered_map<Node*, ForwardFlowNode> branch_map;
- std::unordered_set<Node*> frontier;
- std::vector<Node*> stack = switch_cluster.switches;
- std::vector<bool> visited(graph_->num_node_ids(), false);
- while (!stack.empty()) {
- Node* n = stack.back();
- stack.pop_back();
-
- if (visited[n->id()]) {
- continue;
- }
- visited[n->id()] = true;
-
- // Propagate branch state along each edge of a switch node.
- bool sink_only = true;
- for (const Edge* e : n->out_edges()) {
- Node* out = e->dst();
- if (!out->IsOp()) {
- continue;
- }
- sink_only = false;
- // Propagate branch information.
- ForwardFlowNode& ffn = branch_map[out];
- if (IsSwitch(n)) {
- int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ",
- e->DebugString());
- } else {
- TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn),
- " when joining ", e->DebugString());
- }
- if (IsMerge(out)) {
- if (out->in_edges().size() == ffn.count) {
- frontier.insert(out);
- }
- } else if (!visited[out->id()]) {
- stack.push_back(out);
- }
- }
- if (sink_only) {
- if (!IsIdentity(n)) {
- VLOG(1) << "Feeding into sink: " << n->DebugString();
- }
- }
- }
-
- if (dump_graphs_) {
- for (const auto& kv : branch_map) {
- // Append attribute to the graph if running with logging to make the
- // changes clearer in the visualization.
- kv.first->AddAttr("_XlaFunctionalizeBranch",
- Branch_Name(kv.second.branch));
- }
- }
- return std::make_pair(std::move(branch_map), std::move(frontier));
-}
-
-Status FunctionalizeCond::FunctionalizeInternal() {
- TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
- DeterminePredicateSwitchOrder());
-
- // Iterate from innermost set of clustered switches to outermost, replacing
- // matching switch->merge subgraphs with single XlaIf nodes.
- for (auto it = predicate_switch_order.rbegin();
- it != predicate_switch_order.rend(); ++it) {
- auto& ps = *it;
- VLOG(3) << "Flow down from: " << ps.ToString();
-
- std::unordered_map<Node*, ForwardFlowNode> branch_map;
- std::unordered_set<Node*> frontier;
- TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
- DetermineBranchMapAndFrontier(ps));
-
- if (dump_graphs_)
- LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
- library_);
- TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
-
- struct Hash {
- size_t operator()(const std::pair<Node*, int>& item) const {
- return Hash64Combine(hash<Node*>()(item.first),
- std::hash<int>()(item.second));
- }
- };
-
- // Sort the merge and switch nodes using NodeCmp. The switch-nodes are
- // further grouped (post sorting) by input to the switch node as in the
- // functionalized form each input will be passed in only once. This grouping
- // should retain the sorted order.
- CondArgNodes cond_arg_nodes;
- std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp());
- std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
- for (Node* switch_node : ps.switches) {
- const Edge* e;
- TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
- std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
- if (input_index.find(key) == input_index.end()) {
- input_index[key] = cond_arg_nodes.size();
- cond_arg_nodes.emplace_back(key.first, key.second);
- }
- cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node);
- }
- std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
- std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
-
- TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
- EnsureDominanceAndReturnNonDominatedControlNodes(
- branch_map, ps.switches));
-
- TF_ASSIGN_OR_RETURN(Node * if_node,
- ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
- for (Node* old : old_control_nodes) {
- graph_->AddControlEdge(old, if_node);
- }
-
- for (auto& del_kv : branch_map) {
- graph_->RemoveNode(del_kv.first);
- }
- for (auto& kv : cond_arg_nodes) {
- for (Node* node : kv.switches) {
- graph_->RemoveNode(node);
- }
- }
- if (dump_graphs_)
- LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
- library_);
- }
- return Status::OK();
-}
-
-StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
- const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes) {
- VLOG(2) << "Build if op for " << switch_cluster.name;
-
- NodeDef if_def;
- // Create a new If node using the name of the merge node.
- NodeDefBuilder builder(switch_cluster.name, "XlaIf");
- string branch[] = {"else_branch", "then_branch"};
- for (int i = 0; i < 2; ++i) {
- static std::atomic<int64> sequence_num(0LL);
- int64 id = ++sequence_num;
-
- NameAttrList body_name;
- body_name.set_name(
- strings::StrCat("_functionalize_if_", branch[i], "_", id));
- auto body = xla::MakeUnique<Graph>(graph_->op_registry());
- TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
- merge_nodes, i, body.get()));
- VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
- FunctionDef body_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
- TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
- builder.Attr(branch[i], body_name);
- }
-
- // Build input type.
- std::vector<NodeDefBuilder::NodeOut> inputs;
- DataTypeVector in_arg_types;
- for (auto& kv : cond_arg_nodes) {
- bool inserted = false;
- for (const Node* arg : kv.switches) {
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- builder.ControlInput(in_edge->src()->name());
- } else {
- if (!inserted) {
- DataType dtype = arg->input_type(0);
- inputs.emplace_back(NodeDefBuilder::NodeOut(
- in_edge->src()->name(), in_edge->src_output(), dtype));
- in_arg_types.push_back(dtype);
- inserted = true;
- }
- }
- }
- }
- builder.Attr("Tin", in_arg_types);
-
- // Build output type.
- DataTypeVector out_type;
- for (const Node* merge : merge_nodes) {
- DataType dtype = merge->output_type(0);
- out_type.push_back(dtype);
- }
- builder.Attr("Tout", out_type);
-
- builder.Attr("Tcond", DT_BOOL);
- builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name());
- // Conditional should be the first input ...
- builder.Input(NodeDefBuilder::NodeOut(
- switch_cluster.predicate_edge->src()->name(),
- switch_cluster.predicate_edge->src_output(),
- switch_cluster.predicate_edge->src()->output_type(0)));
- // ... followed by the other inputs.
- builder.Input(inputs);
-
- TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
- TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_));
- return if_node;
-}
-
-Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switches,
- const std::vector<Node*>& merge_nodes,
- int input_edge, Graph* body) {
- VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
- << input_edge;
- std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false);
- std::vector<Node*> node_map(graph_->num_node_ids(), nullptr);
- int arg_count = 0;
- for (auto& kv : cond_arg_nodes) {
- Node* arg_node = nullptr;
- for (const auto* arg : kv.switches) {
- DataType dtype = arg->input_type(0);
- if (arg_node == nullptr) {
- TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
- }
- node_map.at(arg->id()) = arg_node;
- squash_src_outputs.at(arg->id()) = true;
- }
- }
-
- std::vector<Node*> stack;
- stack.reserve(merge_nodes.size());
- for (int j = 0; j < merge_nodes.size(); ++j) {
- Node* node = merge_nodes[j];
- TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
- BuildRetvalNode(body, node->output_type(0),
- /*index=*/j));
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge));
- Node* in = in_edge->src();
- if (node_map.at(in->id()) == nullptr) {
- node_map.at(in->id()) = body->CopyNode(in);
- }
-
- if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
- body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
- node_map.at(node->id()), 0);
- } else {
- body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0);
- // Don't include input nodes that are already just returned in stack.
- continue;
- }
- stack.push_back(in);
- }
-
- return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map,
- body);
-}
-
-Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
- const Edge* predicate_edge,
- Node* if_node) {
- VLOG(3) << "AddInputEdges for " << if_node->name();
- int index = 0;
- graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node,
- index++);
- for (auto& arg : cond_arg_nodes) {
- if (arg.src_output == Graph::kControlSlot) {
- graph_->AddControlEdge(arg.src, if_node);
- } else {
- graph_->AddEdge(arg.src, arg.src_output, if_node, index++);
- }
- }
- return Status::OK();
-}
-
-Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
- Node* if_node) {
- VLOG(3) << "AddOutputEdges for " << if_node->name();
- for (int i = 0; i < outputs.size(); ++i) {
- Node* node = outputs[i];
- std::vector<const Edge*> edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : edges) {
- Node* dst = edge->dst();
- int dst_input = edge->dst_input();
-
- if (edge->src_output() > 0) {
- return errors::Unimplemented("Output of index (", edge->src_output(),
- ") of merge node ", node->name());
- }
-
- int src_output =
- dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
- graph_->RemoveEdge(edge);
- graph_->AddEdge(if_node, src_output, dst, dst_input);
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
- const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes) {
- VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
- << NodesToString(merge_nodes);
-
- // Extract bodies and builds a If operator.
- TF_ASSIGN_OR_RETURN(
- Node * if_node,
- BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
- TF_RETURN_IF_ERROR(
- AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node));
- TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
- // Check that the if_node doesn't feed into itself.
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- CheckNoCycleContains(if_node, graph_->num_node_ids()),
- "ConvertToXlaIf failed.");
-
- return if_node;
-}
-
-Status FunctionalizeCond::Functionalize(Graph* graph,
- FunctionLibraryDefinition* library) {
- VLOG(1) << "FunctionalizeCond::Functionalize";
- FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
- return fc.FunctionalizeInternal();
-}
-
-} // namespace
-
-// Transformation that converts TensorFlow's graph control flow constructs into
-// functional equivalents.
-Status FunctionalizeControlFlow(Graph* graph,
- FunctionLibraryDefinition* library) {
- return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
-}
-
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library) {
@@ -1462,98 +46,26 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph,
library);
- // Note: BuildControlFlowInfo() requires that the graph's source node is
- // connected to all source nodes in the graph. Many graphs violate this
- // invariant.
- std::vector<ControlFlowInfo> cf_info;
- std::vector<string> unreachable_nodes;
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes),
- "FunctionalizeControlFlow failed");
- if (!unreachable_nodes.empty()) {
- return errors::InvalidArgument(
- "The following nodes are unreachable from the source in the graph: ",
- errors::FormatNodeNamesForError(unreachable_nodes));
- }
-
- // Builds Frames, indexed by name.
- std::unordered_map<string, Frame> frames;
- for (Node* node : graph->op_nodes()) {
- const ControlFlowInfo& cf = cf_info[node->id()];
-
- VLOG(2) << "node: " << node->name() << " (" << node->id()
- << ") frame_name: " << cf.frame_name
- << " frame: " << (cf.frame ? cf.frame->name() : "---")
- << " parent_frame: "
- << (cf.parent_frame ? cf.parent_frame->name() : "---");
- TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
-
- Frame& frame = frames[cf.frame_name];
- Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
- if (frame.parent == nullptr) {
- frame.parent = parent;
- frame.name = cf.frame_name;
- ++parent->num_children;
- }
-
- if (IsEnter(node)) {
- Arg arg;
- arg.enter = node;
- TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
- &arg.is_loop_invariant));
- frame.args.push_back(arg);
- } else if (IsLoopCond(node)) {
- frame.loop_cond = node;
- }
- frame.nodes.insert(node);
- }
-
- // Adds frames with no children (i.e., the innermost frames) to a worklist.
- std::deque<Frame*> worklist;
- for (auto& frame : frames) {
- if (frame.second.num_children == 0) {
- worklist.push_back(&frame.second);
- }
- }
-
- // Eliminate loops from innermost to outermost.
- while (!worklist.empty()) {
- Frame* frame = worklist.front();
- worklist.pop_front();
- if (frame->parent == frame) {
- // Skip the root frame.
- continue;
- }
-
- TF_RETURN_IF_ERROR(
- FunctionalizeLoop(lookup_library, graph, frame, library));
-
- // If the parent has no remaining children, add it to the worklist.
- --frame->parent->num_children;
- if (frame->parent->num_children == 0) {
- worklist.push_back(frame->parent);
- }
- }
- // There should be no cycle at this point, since while loops have been removed
- // from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
- for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- CheckNoCycleContains(node, graph->num_node_ids()),
- "FunctionalizeLoop failed.");
- }
- }
+ // Functionalize and remove while loops from graph.
+ TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library));
// FunctionalizeControlFlow is invoked for every function, so the loops's
// bodies and conditionals that were extracted into functions will be handled
// in successive invocations.
- TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
+ TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library));
VLOG(2) << "FunctionalizeControlFlow (final): "
<< dump_graph::DumpGraphToFile("functionalize_final", *graph,
library);
+
return Status::OK();
}
+// Transformation that converts TensorFlow's graph control flow constructs into
+// functional equivalents.
+Status FunctionalizeControlFlow(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index d941041d15..55600f2a8b 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -16,14 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
-// operators, suitable for XLA compilation. If lookup_library is provided, use
-// it to make the library for control flow self-contained.
+// operators and tf.cond() conditionals into function If operators, suitable for
+// XLA compilation. If lookup_library is provided, use it to make the library
+// for control flow self-contained.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index ccf249b35d..cc52057f21 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -37,12 +37,12 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Returns the names of the "then" and "else" functions for the XlaIf node in a
+// Returns the names of the "then" and "else" functions for the If node in a
// graph.
Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
NameAttrList* then_fn, NameAttrList* else_fn) {
for (const NodeDef& node : graph.node()) {
- if (node.op() == "XlaIf") {
+ if (node.op() == "If") {
*op_name = node.name();
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
@@ -52,7 +52,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
return Status::OK();
}
}
- return errors::NotFound("No XlaIf node found in graph");
+ return errors::NotFound("No If node found in graph");
}
// Graph:
@@ -115,8 +115,13 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
std::initializer_list<Input>{less, y, x}, then_fn,
else_fn, {DT_INT32});
+ auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
+ // TODO(jpienaar): Create wrapper for IfOp.
+ for (NodeDef& n : *expected.mutable_node()) {
+ if (n.op() == "XlaIf") n.set_op("If");
+ }
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
@@ -1013,63 +1018,5 @@ TEST(FunctionalizeControlFlow, Complex) {
}
}
-TEST(FunctionalizeControlFlow, Cycle) {
- std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- // -----------------------------------------------------
- // | |
- // | v
- // less -> switch_1 --> add -> merge_1 -> identity -> switch_2
- // | ^ |
- // | | v
- // --------> one -------------------------> add_2 ---> merge_2
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
-
- auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
- auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
- auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
- auto two =
- ops::Const<int32>(scope.WithOpName("cond/two")
- .WithControlDependencies(switch_1.output_true),
- 2);
- auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"),
- switch_1.output_true, two);
- auto one =
- ops::Const<int32>(scope.WithOpName("cond/one")
- .WithControlDependencies(switch_1.output_false),
- 1);
- auto add = ops::Add(scope.WithOpName("cond/false/add"),
- switch_1.output_false, one);
-
- auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"),
- std::initializer_list<Input>{add, mul});
- auto identity =
- ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output);
- auto switch_2 =
- ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less);
- auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"),
- switch_2.output_false, one);
- auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"),
- switch_2.output_true, two);
- auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"),
- std::initializer_list<Input>{add_2, mul_2});
- TF_ASSERT_OK(scope.ToGraph(graph.get()));
- }
- // No cycle before functionalize control flow.
- TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph));
- FunctionLibraryDefinition library(OpRegistry::Global(), {});
- // switch_1 and switch_2 have the same switch depth. They are replaced by a
- // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle:
- // less -> XlaIf <--> identity.
- Status status = FunctionalizeControlFlow(graph.get(), &library);
- EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle"))
- << status.error_message();
- EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}"))
- << status.error_message();
-}
-
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
new file mode 100644
index 0000000000..924fcdd9cd
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace tensorflow {
+
+bool NodeCmpByNameResourcesLast::operator()(const Node* lhs,
+ const Node* rhs) const {
+ bool lhs_is_resource =
+ lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
+ bool rhs_is_resource =
+ rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
+ return std::tie(lhs_is_resource, lhs->name()) <
+ std::tie(rhs_is_resource, rhs->name());
+}
+
+xla::StatusOr<Node*> AddNodeDefToGraph(const NodeDef& node_def, Graph* graph) {
+ Status status;
+ Node* inserted_node = graph->AddNode(node_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ return inserted_node;
+}
+
+xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
+ const char* const kRetValOp = "_Retval";
+ NodeDef ret_def;
+ ret_def.set_op(kRetValOp);
+ ret_def.set_name(strings::StrCat(kRetValOp, index));
+ AddNodeAttr("T", type, &ret_def);
+ AddNodeAttr("index", index, &ret_def);
+ return AddNodeDefToGraph(ret_def, graph);
+}
+
+// Check that the graph has no cycle containing the given node.
+Status CheckNodeNotInCycle(const Node* node, const int num_nodes) {
+ std::vector<const Node*> ready;
+ ready.push_back(node);
+ std::vector<bool> visited(num_nodes);
+ while (!ready.empty()) {
+ const Node* current_node = ready.back();
+ ready.pop_back();
+ visited[current_node->id()] = true;
+ for (const Edge* out : current_node->out_edges()) {
+ if (out->dst() == node) {
+ return errors::Internal("Detected a cycle: ", FormatNodeForError(*node),
+ " (", node->def().op(), ") feeds into itself.");
+ } else if (!visited[out->dst()->id()]) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
new file mode 100644
index 0000000000..a0544b69e9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/graph/graph.h"
+
+// Utility functions shared between functionalize cond and while.
+
+namespace tensorflow {
+
+// Check that the graph has no cycle containing the given node.
+Status CheckNodeNotInCycle(const Node* node, const int num_nodes);
+
+// Comparison function used for sorting nodes consistently.
+// a) resource variables are last, and
+// b) sort lexicographically by name (for deterministic output).
+struct NodeCmpByNameResourcesLast {
+ bool operator()(const Node* lhs, const Node* rhs) const;
+};
+
+// Returns the Node* created from the NodeDef in the Graph.
+xla::StatusOr<Node*> AddNodeDefToGraph(const NodeDef& node_def, Graph* graph);
+
+// Build a retval node of given type and index.
+xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
+
+// Returns a textual representation of the names of the nodes in the input.
+template <typename T>
+string NodesToString(const T& nodes) {
+ return strings::StrCat("{",
+ str_util::Join(nodes, ",",
+ [](string* output, const Node* node) {
+ strings::StrAppend(output,
+ node->name());
+ }),
+ "}");
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
new file mode 100644
index 0000000000..4fd134c698
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -0,0 +1,668 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_while.h"
+
+#include <algorithm>
+#include <deque>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+
+namespace tensorflow {
+namespace {
+
+using xla::StatusOr;
+
+// Information about a loop argument.
+struct Arg {
+ // Every loop argument has an Enter node.
+ Node* enter;
+
+ // Is the loop argument a loop-invariant value? Taken from the `is_constant`
+ // attribute on the Enter node.
+ bool is_loop_invariant;
+
+ // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
+ // arguments must have all of the following nodes:
+ Node* merge = nullptr;
+ Node* switch_node = nullptr;
+ Node* next_iteration = nullptr;
+ Node* exit = nullptr;
+};
+
+// Information about a loop frame.
+struct Frame {
+ string name;
+
+ // Pointer to the parent frame. The root frame has a pointer to itself.
+ Frame* parent = nullptr;
+ int num_children = 0;
+
+ // Arguments to this loop.
+ std::vector<Arg> args;
+
+ // The loop condition of the loop. There should be exactly one loop condition
+ // in every loop.
+ Node* loop_cond = nullptr;
+
+ // Set of nodes that belong to the loop frame.
+ std::unordered_set<Node*> nodes;
+};
+
+// Copies a subgraph from `graph` to `output` by performing a reverse DFS
+// starting at nodes in vector `stack`.
+// `node_map` is a vector indexed by source node ID to dest nodes.
+// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
+// before the traversal clients can cut the graph. If a frame is provided (frame
+// != nullptr), then this functions will return an error if the
+// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
+// cut the graph and prevent the traversal from escaping.
+//
+// `squash_src_outputs` contains a bool for each source node ID. If true, then
+// the source output on that node will be replaced by zero when copied. This is
+// used when replacing a Switch node with an _Arg node. The output we are
+// taking from the Switch node was not necessarily the first output, but _Arg
+// nodes only have one output. By adding the Switch node to `squash_src_outputs`
+// we rewrite the src_output of the corresponding edge to be 0.
+Status CopySubgraph(const Graph& graph, const Frame* frame,
+ std::vector<Node*> stack,
+ const std::vector<bool>& squash_src_outputs,
+ std::vector<Node*>* node_map, Graph* output) {
+ VLOG(3) << "Stack: " << NodesToString(stack);
+ std::vector<bool> visited(graph.num_node_ids(), false);
+ while (!stack.empty()) {
+ Node* n = stack.back();
+ stack.pop_back();
+
+ VLOG(5) << "Copying node " << n->name();
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
+ // We traversed out of the loop frame, without encountering a cut node.
+ return errors::Internal("Graph traversal of loop frame ", frame->name,
+ " escaped frame at ", src->name(),
+ " without encountering an argument node.");
+ }
+ if ((*node_map)[src->id()] == nullptr) {
+ (*node_map)[src->id()] = output->CopyNode(src);
+ stack.push_back(src);
+ }
+ Node* src_copy = (*node_map)[e->src()->id()];
+ int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
+ ? 0
+ : e->src_output();
+ Node* dst_copy = (*node_map)[e->dst()->id()];
+ output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
+ const char* const kArgOp = "_Arg";
+ NodeDef arg_def;
+ NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
+ builder.Attr("T", type);
+ builder.Attr("index", index);
+ TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
+ return AddNodeDefToGraph(arg_def, graph);
+}
+
+// Builds a graph for the loop condition.
+Status BuildLoopCondition(const Graph& graph, Frame* frame,
+ std::unique_ptr<Graph>* cond_output) {
+ VLOG(2) << "Building loop condition for " << frame->name;
+ *cond_output = absl::make_unique<Graph>(graph.op_registry());
+ Graph* output = cond_output->get();
+
+ // Map from nodes in the original graph to the condition graph.
+ std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+ std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+ // Build one _Arg node for each Enter node.
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+
+ TF_ASSIGN_OR_RETURN(Node * arg_node,
+ BuildArgNode(output, arg.enter->input_type(0), i));
+ if (arg.is_loop_invariant) {
+ node_map[arg.enter->id()] = arg_node;
+ } else {
+ node_map[arg.merge->id()] = arg_node;
+ }
+ }
+
+ // Build a Retval node for the loop condition. The LoopCond nodes are always
+ // boolean because of the type constraints on the LoopCond op.
+ TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
+ BuildRetvalNode(output, DT_BOOL, 0));
+
+ // Performs a reverse DFS, copying nodes and edges to the output graph.
+ // The _Arg and _Retval nodes were added unconditionally above, so we are
+ // guaranteed to get the correct function signature.
+ return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
+ &node_map, output);
+}
+
+// Builds a graph for the loop body.
+Status BuildLoopBody(const Graph& graph, Frame* frame,
+ DataTypeVector* arg_types,
+ std::unique_ptr<Graph>* body_output) {
+ VLOG(2) << "Building loop body for " << frame->name;
+ *body_output = absl::make_unique<Graph>(graph.op_registry());
+ Graph* output = body_output->get();
+
+ // Map from nodes in the original graph to the condition graph.
+ std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+ std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+ // Build one _Arg node for each Enter node.
+ std::vector<Node*> next_iterations;
+ next_iterations.reserve(frame->args.size());
+ arg_types->reserve(frame->args.size());
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+
+ DataType dtype = arg.enter->input_type(0);
+ arg_types->push_back(dtype);
+
+ TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
+
+ if (dtype == DT_RESOURCE) {
+ // The convention of the XLA bridge is that resource variable arguments
+ // are only inputs to the loop body and have no corresponding output.
+ // TODO(b/37741920): change the convention so that DT_RESOURCE variables
+ // are both inputs and outputs, and then remove this case.
+ TF_RET_CHECK(arg.is_loop_invariant);
+ node_map[arg.enter->id()] = arg_node;
+ } else {
+ TF_ASSIGN_OR_RETURN(Node * retval_node,
+ BuildRetvalNode(output, dtype, i));
+
+ if (arg.is_loop_invariant) {
+ // Argument is loop-invariant. Forward it from the Arg to the Retval.
+ node_map[arg.enter->id()] = arg_node;
+ output->AddEdge(arg_node, 0, retval_node, 0);
+ } else {
+ // Argument is loop-varying.
+ node_map[arg.switch_node->id()] = arg_node;
+ // The Switch node has two outputs, but _Arg only has one. This tells
+ // the CopySubgraph function to rewrite the output number of edges from
+ // the _Arg node to be 0 rather than copying the output number from the
+ // Switch node.
+ squash_src_outputs[arg.switch_node->id()] = true;
+ node_map[arg.next_iteration->id()] = retval_node;
+ next_iterations.push_back(arg.next_iteration);
+ }
+ }
+ }
+
+ // Performs a reverse DFS, copying nodes and edges to the output graph.
+ // The _Arg and _Retval nodes were added unconditionally above, so we are
+ // guaranteed to get the correct function signature.
+ TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
+ squash_src_outputs, &node_map, output));
+
+ return Status::OK();
+}
+
+// Copy the FunctionDef of given function from lookup_library to library, if
+// it can be found in lookup_library but is missing from library.
+Status AddMissingFunctionByName(const string& function_name,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ if (!library->Find(function_name) && lookup_library->Find(function_name)) {
+ return library->AddFunctionDef(*lookup_library->Find(function_name));
+ }
+ return Status::OK();
+}
+
+// Iterate over all functions that the given fdef refers to. Copy the missing
+// FunctionDefs from lookup_library to library.
+Status AddMissingFunctionDef(const FunctionDef& fdef,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ TF_RET_CHECK(lookup_library);
+ for (const NodeDef& node : fdef.node_def()) {
+ if (library->Find(node.op())) {
+ continue;
+ }
+ // The function referred by 'SymbolicGradient' node is specified in its
+ // attribute 'f'.
+ if (node.op() == FunctionLibraryDefinition::kGradientOp) {
+ const AttrValue* attr =
+ AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
+ if (!attr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const string& func_name = attr->func().name();
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(func_name, lookup_library, library));
+ // Copy the user-defined gradient function if it exists.
+ const string grad_name = lookup_library->FindGradient(func_name);
+ if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(grad_name, lookup_library, library));
+ GradientDef grad_def;
+ grad_def.set_function_name(func_name);
+ grad_def.set_gradient_func(grad_name);
+ TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
+ }
+ } else if (lookup_library->Find(node.op())) {
+ TF_RETURN_IF_ERROR(
+ library->AddFunctionDef(*lookup_library->Find(node.op())));
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph, Frame* frame,
+ FunctionLibraryDefinition* library) {
+ VLOG(2) << "Frame " << frame->name << " before: "
+ << dump_graph::DumpGraphToFile("functionalize_before", *graph,
+ library);
+
+ // Split loop-varying Enter nodes with multiple successors. If the same
+ // Tensor is fed as input to multiple loop arguments, we may end up with a
+ // shared Enter node. We clone Enter nodes with multiple successors to
+ // maintain the invariant of a unique Enter node per argument of the final
+ // loop.
+ std::vector<Arg> args;
+ for (const Arg& arg : frame->args) {
+ if (arg.is_loop_invariant) {
+ args.push_back(arg);
+ } else {
+ std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
+ arg.enter->out_edges().end());
+ for (int i = 0; i < edges.size(); ++i) {
+ if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
+ continue;
+ }
+ TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
+ Arg new_arg;
+ new_arg.is_loop_invariant = false;
+ if (i == 0) {
+ new_arg.enter = arg.enter;
+ } else {
+ new_arg.enter = graph->CopyNode(arg.enter);
+ frame->nodes.insert(new_arg.enter);
+ for (Edge const* e : arg.enter->in_edges()) {
+ graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
+ e->IsControlEdge() ? Graph::kControlSlot : 0);
+ }
+ Node* dst = edges[i]->dst();
+ int dst_input = edges[i]->dst_input();
+ graph->RemoveEdge(edges[i]);
+ graph->AddEdge(new_arg.enter, 0, dst, dst_input);
+ }
+ args.push_back(new_arg);
+ }
+ }
+ }
+ frame->args = std::move(args);
+
+ std::sort(frame->args.begin(), frame->args.end(),
+ [](const Arg& a, const Arg& b) {
+ return NodeCmpByNameResourcesLast()(a.enter, b.enter);
+ });
+
+ if (frame->loop_cond == nullptr) {
+ return errors::InvalidArgument("Loop ", frame->name,
+ " has no LoopCond node");
+ }
+
+ // Find the set of Switch nodes that are successors of the LoopCond.
+ std::unordered_set<Node*> switches;
+ for (const Edge* edge : frame->loop_cond->out_edges()) {
+ if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
+ edge->dst_input() == 1) {
+ switches.insert(edge->dst());
+ }
+ }
+
+ // For each non-constant argument, looks for the following pattern of nodes:
+ // Enter ----> Merge --------> Switch --> Exit
+ // ^ ^
+ // | |
+ // NextIteration LoopCond
+ // ^ ^
+ // | |
+ // ... ...
+ for (Arg& arg : frame->args) {
+ if (!arg.is_loop_invariant) {
+ // Follow the edge from the Enter to Merge.
+ const Edge* enter_merge = nullptr;
+ for (const Edge* e : arg.enter->out_edges()) {
+ // Ignore control-edges to the sink node. These are allowed by the
+ // graph invariants, although probably they should have been stripped
+ // off earlier.
+ if (e->IsControlEdge() && e->dst()->IsSink()) {
+ continue;
+ }
+ if (enter_merge != nullptr) {
+ return errors::Internal("Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.enter),
+ " has multiple successors: ",
+ FormatNodeForError(*enter_merge->dst()),
+ " and ", FormatNodeForError(*e->dst()));
+ }
+ enter_merge = e;
+ }
+ if (enter_merge == nullptr) {
+ return errors::Internal("Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.enter),
+ " has zero successors");
+ }
+ arg.merge = enter_merge->dst();
+ if (!IsMerge(arg.merge)) {
+ return errors::InvalidArgument(
+ "Successor of Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.merge),
+ " is not a Merge node; got: ", arg.merge->type_string());
+ }
+
+ // Find the NextIteration from the merge. There should be two inputs to
+ // the Merge and the NextIteration should be the other input.
+ if (arg.merge->input_types().size() != 2) {
+ return errors::InvalidArgument(
+ "Unexpected number of inputs to Merge node for loop-varying "
+ "argument ",
+ FormatNodeForError(*arg.merge), "; expected 2, got ",
+ arg.merge->input_types().size());
+ }
+ TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
+ &arg.next_iteration));
+ if (!IsNextIteration(arg.next_iteration)) {
+ return errors::InvalidArgument(
+ "Expected NextIteration node as input to Merge node; got node ",
+ FormatNodeForError(*arg.next_iteration), " with kind ",
+ arg.next_iteration->type_string());
+ }
+
+ // Find the Switch successor of the Merge. There should be exactly one
+ // Switch node that is a successor of both the Merge and the LoopCond.
+ for (const Edge* edge : arg.merge->out_edges()) {
+ if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
+ switches.find(edge->dst()) != switches.end()) {
+ if (arg.switch_node != nullptr) {
+ return errors::InvalidArgument("Duplicate Switch successors to ",
+ FormatNodeForError(*arg.merge));
+ }
+ arg.switch_node = edge->dst();
+ }
+ }
+ if (arg.switch_node == nullptr) {
+ return errors::InvalidArgument("Missing Switch successor to ",
+ FormatNodeForError(*arg.merge));
+ }
+
+ // Update the device on the Identity outputs of the switch to match their
+ // target. These Identity outputs do not
+
+ // Loop over the switch node's output to:
+ // - Find the Exit successor.
+ // - Set the sharding on all Identity outputs of the switch. These
+ // identity nodes are values used by the loop body or condition.
+ // The Identity node may have the wrong device so copy the device from
+ // one of its outputs instead.
+ std::deque<const Edge*> possible_exit;
+ for (const Edge* edge : arg.switch_node->out_edges()) {
+ if (edge->src_output() == 0) {
+ possible_exit.push_back(edge);
+ }
+ if (IsIdentity(edge->dst())) {
+ TF_RETURN_IF_ERROR(
+ SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
+ }
+ }
+ // TODO(b/67425339): Allow general graph between switch and exit.
+ while (!possible_exit.empty()) {
+ const Edge* edge = possible_exit.front();
+ possible_exit.pop_front();
+ if (IsExit(edge->dst())) {
+ if (arg.exit != nullptr) {
+ return errors::InvalidArgument(
+ "Duplicate Exit successors to ",
+ FormatNodeForError(*arg.switch_node));
+ }
+ arg.exit = edge->dst();
+ } else {
+ if (!IsIdentity(edge->dst())) {
+ return errors::Unimplemented("General graph between switch (",
+ FormatNodeForError(*arg.switch_node),
+ ") and exit node of frame ",
+ frame->name, " not supported yet.");
+ }
+ for (const Edge* out : edge->dst()->out_edges()) {
+ possible_exit.push_back(out);
+ }
+ }
+ }
+ }
+ }
+
+ // Builds the condition and body functions.
+ std::unique_ptr<Graph> cond_graph;
+ TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ DataTypeVector arg_types;
+ std::unique_ptr<Graph> body_graph;
+ TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+
+ VLOG(2) << "Frame " << frame->name << " condition: "
+ << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
+ << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
+
+ static std::atomic<int64> sequence_num(0LL);
+ int64 id = ++sequence_num;
+ NameAttrList cond_name;
+ cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
+ NameAttrList body_name;
+ body_name.set_name(strings::StrCat("_functionalize_body_", id));
+ FunctionDef cond_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
+ FunctionDef body_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
+
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+ if (lookup_library) {
+ // Copy missing FunctionDefs from lookup_library to library to make library
+ // self-contained.
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(cond_fdef, lookup_library, library));
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(body_fdef, lookup_library, library));
+ }
+
+ // Builds a While operator.
+ NodeDef while_def;
+ NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ builder.Attr("T", arg_types);
+ builder.Attr("cond", cond_name);
+ builder.Attr("body", body_name);
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ builder.ControlInput(in_edge->src()->name());
+ } else {
+ inputs.push_back(NodeDefBuilder::NodeOut(
+ in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
+ }
+ }
+ builder.Input(inputs);
+ TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
+ TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
+
+ // Copies edges to the Enter nodes and from the Exit nodes onto the While.
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ graph->AddControlEdge(in_edge->src(), while_node);
+ } else {
+ graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
+ }
+
+ if (!arg.is_loop_invariant) {
+ // Add output edges if the output of the loop is consumed.
+ if (arg.exit != nullptr) {
+ std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
+ arg.exit->out_edges().end());
+ for (const Edge* edge : edges) {
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ graph->RemoveEdge(edge);
+
+ if (dst_input == Graph::kControlSlot) {
+ graph->AddControlEdge(while_node, dst);
+ } else {
+ graph->AddEdge(while_node, i, dst, dst_input);
+ }
+ }
+ }
+ }
+ }
+
+ // Remove the old nodes from the graph, and add the while node to the parent
+ // frame.
+ for (Node* node : frame->nodes) {
+ graph->RemoveNode(node);
+ }
+ frame->nodes.clear();
+ frame->parent->nodes.insert(while_node);
+
+ VLOG(2) << "Frame " << frame->name << " after: "
+ << dump_graph::DumpGraphToFile("functionalize_after", *graph,
+ library);
+
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph,
+ FunctionLibraryDefinition* library) {
+ // Note: BuildControlFlowInfo() requires that the graph's source node is
+ // connected to all source nodes in the graph. Many graphs violate this
+ // invariant.
+ std::vector<ControlFlowInfo> cf_info;
+ std::vector<string> unreachable_nodes;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
+ if (!unreachable_nodes.empty()) {
+ return errors::InvalidArgument(
+ "The following nodes are unreachable from the source in the graph: ",
+ errors::FormatNodeNamesForError(unreachable_nodes));
+ }
+
+ // Builds Frames, indexed by name.
+ std::unordered_map<string, Frame> frames;
+ for (Node* node : graph->op_nodes()) {
+ const ControlFlowInfo& cf = cf_info[node->id()];
+
+ VLOG(2) << "node: " << node->name() << " (" << node->id()
+ << ") frame_name: " << cf.frame_name
+ << " frame: " << (cf.frame ? cf.frame->name() : "---")
+ << " parent_frame: "
+ << (cf.parent_frame ? cf.parent_frame->name() : "---");
+ TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
+
+ Frame& frame = frames[cf.frame_name];
+ Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
+ if (frame.parent == nullptr) {
+ frame.parent = parent;
+ frame.name = cf.frame_name;
+ ++parent->num_children;
+ }
+
+ if (IsEnter(node)) {
+ Arg arg;
+ arg.enter = node;
+ TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
+ &arg.is_loop_invariant));
+ frame.args.push_back(arg);
+ } else if (IsLoopCond(node)) {
+ frame.loop_cond = node;
+ }
+ frame.nodes.insert(node);
+ }
+
+ // Adds frames with no children (i.e., the innermost frames) to a worklist.
+ std::deque<Frame*> worklist;
+ for (auto& frame : frames) {
+ if (frame.second.num_children == 0) {
+ worklist.push_back(&frame.second);
+ }
+ }
+
+ // Eliminate loops from innermost to outermost.
+ while (!worklist.empty()) {
+ Frame* frame = worklist.front();
+ worklist.pop_front();
+ if (frame->parent == frame) {
+ // Skip the root frame.
+ continue;
+ }
+
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
+
+ // If the parent has no remaining children, add it to the worklist.
+ --frame->parent->num_children;
+ if (frame->parent->num_children == 0) {
+ worklist.push_back(frame->parent);
+ }
+ }
+
+ // There should be no cycle at this point, since while loops have been removed
+ // from graph.
+ // Check that the newly added XlaWhile nodes don't feed into themselves.
+ for (const Node* node : graph->op_nodes()) {
+ if (node->def().op() == "XlaWhile") {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNodeNotInCycle(node, graph->num_node_ids()),
+ "Functionalizing loop failed.");
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h
new file mode 100644
index 0000000000..a708c6e4ec
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_while.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Transformation that converts tf.while_loop() loops into functional While
+// operators, suitable for XLA compilation. If lookup_library is provided, use
+// it to make the library for control flow self-contained.
+Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph, FunctionLibraryDefinition* library);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 5da7972397..674720e22f 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -120,45 +120,30 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
{expanded_filter_shape.dims() - 2});
}
-// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
-// zeros for the cross-depth filters. Used to build a depthwise convolution.
-xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter,
- xla::XlaBuilder* builder) {
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dims() - 2;
+ int64 output_feature_dim = filter_shape.dims() - 1;
+ int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
+ int64 input_feature = filter_shape.dim_size(input_feature_dim);
// Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = expanded_filter_shape;
- implicit_broadcast_filter_shape.set_dim(
- implicit_broadcast_filter_shape.dims() - 2, 1);
- implicit_broadcast_filter_shape.set_dim(
- implicit_broadcast_filter_shape.dims() - 1,
- depthwise_multiplier * input_feature);
- auto implicit_broadcast_filter =
- xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-
- // Broadcast the filter to [H, W, ..., M, M*N].
- auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
- auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero);
-
- // If the filter mask is set, choose the broadcasted filter, othwerwise,
- // choose zero.
- return xla::Select(CreateExpandedFilterMask(filter_shape, builder),
- expanded_filter, expanded_zero);
+ TensorShape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dim(output_feature_dim,
+ depthwise_multiplier * input_feature);
+ return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
}
-// Inverse of ExpandFilterForDepthwiseConvolution.
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
const TensorShape& filter_shape,
DataType dtype,
const xla::XlaOp& filter_backprop,
xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
auto masked_expanded_filter = xla::Select(
CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
CreateExpandedZero(filter_shape, dtype, builder));
@@ -168,8 +153,7 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
// ExpandedZero guarantees that only one element is non zero, so there
// cannot be accumulated precision error.
xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype),
- {expanded_filter_shape.dims() - 2}),
+ *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
filter_shape.dim_sizes());
}
@@ -245,15 +229,9 @@ class ConvOp : public XlaOpKernel {
"input and filter must have the same depth: ", in_depth,
" vs ", input_shape.dim_size(feature_dim)));
- xla::XlaBuilder* b = ctx->builder();
-
xla::XlaOp filter = ctx->Input(1);
- TensorShape expanded_filter_shape = filter_shape;
if (depthwise_) {
- filter = ExpandFilterForDepthwiseConvolution(
- filter_shape, ctx->input_type(0), filter, b);
- expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
}
xla::ConvolutionDimensionNumbers dims;
@@ -280,14 +258,15 @@ class ConvOp : public XlaOpKernel {
int64 unused_output_size;
OP_REQUIRES_OK(
ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), expanded_filter_shape.dim_size(i),
+ input_shape.dim_size(dim), filter_shape.dim_size(i),
rhs_dilation[i], window_strides[i], padding_,
&unused_output_size, &padding[i].first, &padding[i].second));
}
- xla::XlaOp conv =
- xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
- lhs_dilation, rhs_dilation, dims);
+ xla::XlaOp conv = xla::ConvGeneralDilated(
+ ctx->Input(0), filter, window_strides, padding, lhs_dilation,
+ rhs_dilation, dims,
+ /*feature_group_count=*/depthwise_ ? in_depth : 1);
ctx->SetOutput(0, conv);
}
@@ -388,7 +367,6 @@ class ConvBackpropInputOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::XlaBuilder* b = ctx->builder();
auto filter = ctx->Input(1);
auto out_backprop = ctx->Input(2);
@@ -425,12 +403,6 @@ class ConvBackpropInputOp : public XlaOpKernel {
rhs_dilation[i] = dilations_[dim];
}
- // If this is a depthwise convolution, expand the filter.
- if (depthwise_) {
- filter = ExpandFilterForDepthwiseConvolution(
- filter_shape, ctx->input_type(1), filter, b);
- }
-
// Mirror the filter in the spatial dimensions.
xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
@@ -438,7 +410,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
// = gradients (with padding and dilation) <conv> mirrored_weights
xla::XlaOp in_backprop = xla::ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums);
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
+ filter_shape.dim_size(num_spatial_dims_ + 1)
+ : 1);
ctx->SetOutput(0, in_backprop);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 35de96e0aa..44140304fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// gather = s32[3,2] gather(operand, indices),
- // output_window_dims={0},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3, 1}
+ // slice_sizes={3, 1}
//
//
// Example of an N-D gather pulling out slices of shape [1,1,2] out of a
@@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3,2] parameter(0)
// indices = s32[2,2] parameter(1)
// gather = s32[2,2] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0,1},
- // gather_dims_to_operand_dims={0,1},
+ // offset_dims={1},
+ // collapsed_slice_dims={0,1},
+ // start_index_map={0,1},
// index_vector_dim=0,
- // window_bounds={1,1,2}
+ // slice_sizes={1,1,2}
xla::GatherDimensionNumbers dim_numbers;
- std::vector<int64> window_bounds;
- window_bounds.reserve(input_shape.dims());
+ std::vector<int64> slice_sizes;
+ slice_sizes.reserve(input_shape.dims());
for (int64 i = 0; i < input_shape.dims(); i++) {
int64 window_bound;
if (axis <= i && i < (axis + num_index_dims)) {
- dim_numbers.add_elided_window_dims(i);
+ dim_numbers.add_collapsed_slice_dims(i);
window_bound = 1;
} else {
window_bound = input_shape.dim_size(i);
}
- window_bounds.push_back(window_bound);
+ slice_sizes.push_back(window_bound);
if (i < axis) {
- dim_numbers.add_output_window_dims(i);
+ dim_numbers.add_offset_dims(i);
} else if (i >= (axis + num_index_dims)) {
int64 indices_rank =
indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
- dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
+ dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
}
}
dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
: indices_shape.dims());
for (int64 i = axis; i < axis + num_index_dims; i++) {
- dim_numbers.add_gather_dims_to_operand_dims(i);
+ dim_numbers.add_start_index_map(i);
}
- *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index e72200bfbc..19dd38c46e 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -25,7 +25,10 @@ class IdentityOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
for (int i = 0; i < ctx->num_inputs(); ++i) {
- ctx->SetOutput(i, ctx->Input(i));
+ // Forwards using the underlying op_kernel_context so both tensor and
+ // resource values are forwarded correctly.
+ ctx->op_kernel_context()->set_output(i,
+ ctx->op_kernel_context()->input(i));
}
}
@@ -35,9 +38,10 @@ class IdentityOp : public XlaOpKernel {
// XLA_* devices also register a "real" Identity operator so we suppress the
// dummy operator using CompilationOnly().
-REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp);
-
-REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp);
+REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(),
+ IdentityOp);
+REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(),
+ IdentityOp);
REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6a7eb8d90c..6e1dbf5472 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -200,21 +200,10 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
}
- bool resource_variable_seen = false;
- for (int i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input_type(i) == DT_RESOURCE) {
- resource_variable_seen = true;
- } else {
- OP_REQUIRES(
- ctx, !resource_variable_seen,
- errors::FailedPrecondition(
- "Resource variables and regular inputs cannot be interleaved."));
- }
- }
-
- xla::XlaOp outputs = xla::Conditional(
- ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
- xla::Tuple(b, inputs), *else_result.computation);
+ auto input_tuple = xla::Tuple(b, inputs);
+ xla::XlaOp outputs =
+ xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation,
+ input_tuple, *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index d962ef4a5f..c0afccaa5b 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -95,10 +95,24 @@ class ReverseV2Op : public XlaOpKernel {
std::vector<int64> axes;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
+ // witnessed_axes is used to ensure that the same axis is not marked to be
+ // reversed multiple times.
+ gtl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);
+
for (int d = 0; d < axes.size(); ++d) {
- OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()),
- errors::InvalidArgument(axes[d], " is out of range [0, ",
- x_shape.dims(), ")."));
+ OP_REQUIRES(
+ ctx, (-x_shape.dims() <= axes[d]) && (axes[d] < x_shape.dims()),
+ errors::InvalidArgument(axes[d], " is out of range [-",
+ x_shape.dims(), ", ", x_shape.dims(), ")."));
+ // Axes can be negative and are shifted to the canonical index before
+ // being lowered to HLO.
+ if (axes[d] < 0) {
+ axes[d] += x_shape.dims();
+ }
+ OP_REQUIRES(ctx, !witnessed_axes[axes[d]],
+ errors::InvalidArgument("canonicalized axis ", axes[d],
+ " was repeated."));
+ witnessed_axes[axes[d]] = true;
}
ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes));
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 1233a37565..2c7213f322 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel {
bool one_dimension_is_broadcasted_without_multiple = true;
for (int i = 0; i < input_dims; ++i) {
int multiple = literal.Get<int>({i});
- OP_REQUIRES(ctx, multiple,
+ OP_REQUIRES(ctx, multiple >= 0,
errors::InvalidArgument("Expected multiples[", i,
"] >= 0, but got ", multiple));
int64 new_dim = input_shape.dim_size(i) * multiple;
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 04fa10108c..febb638e5e 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// We can grab entire blocks using gather
if (n > block_size) {
// Construct the starting indices of the diagonal blocks
- auto gather_indices =
+ auto start_indices =
Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks),
xla::ConstantR0<int32>(builder, block_size)),
/*broadcast_sizes=*/{2}),
@@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// Gather the diagonal blocks
xla::GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(ndims - 1);
- dim_numbers.add_output_window_dims(ndims);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 2);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims);
+ dim_numbers.add_start_index_map(ndims - 2);
+ dim_numbers.add_start_index_map(ndims - 1);
dim_numbers.set_index_vector_dim(1);
- diag_blocks = Gather(a, gather_indices, dim_numbers,
- /*window_bounds=*/{block_size, block_size});
+ diag_blocks = Gather(a, start_indices, dim_numbers,
+ /*slice_sizes=*/{block_size, block_size});
}
// The last block might be smaller than the block size,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 226c89bcf1..ac1deae4b2 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -310,7 +311,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
- auto step_container = xla::MakeUnique<ScopedStepContainer>(
+ auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
@@ -791,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
- // Copy the host transfer metadata to the result.
- for (const auto& send : host_compute_sends_) {
- *result->host_compute_metadata.add_device_to_host() = send.second;
- }
- for (const auto& recv : host_compute_recvs_) {
- *result->host_compute_metadata.add_host_to_device() = recv.second;
- }
-
// Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
@@ -816,6 +809,30 @@ Status XlaCompiler::GetChannelHandle(const string& key,
return Status::OK();
}
+Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
+ xla::ChannelHandle* channel) {
+ auto result = channels_.emplace(key, xla::ChannelHandle());
+ if (result.second) {
+ TF_ASSIGN_OR_RETURN(result.first->second,
+ client()->CreateHostToDeviceChannelHandle());
+ }
+ *channel = result.first->second;
+ VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
+ return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
+ xla::ChannelHandle* channel) {
+ auto result = channels_.emplace(key, xla::ChannelHandle());
+ if (result.second) {
+ TF_ASSIGN_OR_RETURN(result.first->second,
+ client()->CreateDeviceToHostChannelHandle());
+ }
+ *channel = result.first->second;
+ VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
+ return Status::OK();
+}
+
namespace {
void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 25332c8d8e..fde47dbdec 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -332,6 +332,16 @@ class XlaCompiler {
// same XlaCompiler.
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
+ // Retrieves the host-to-device channel handle associated with `key`.
+ // Allocates a new channel handle if none exists.
+ Status GetHostToDeviceChannelHandle(const string& key,
+ xla::ChannelHandle* channel);
+
+ // Retrieves the device-to-host channel handle associated with `key`.
+ // Allocates a new channel handle if none exists.
+ Status GetDeviceToHostChannelHandle(const string& key,
+ xla::ChannelHandle* channel);
+
// Sets the shapes and types for the device to host transfer associated with
// 'key'.
Status SetDeviceToHostMetadata(const string& key,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be00ed8813..7227df9649 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -821,7 +821,10 @@ TEST_F(XlaCompilerTest, Variables) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
- auto write = ops::AssignAddVariableOp(scope, var, a);
+ // Adds an identity op around the resource to make sure identity ops propagate
+ // resources correctly.
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+ auto write = ops::AssignAddVariableOp(scope, identity, a);
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index fdf13bb18c..2cf77b71fb 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -161,7 +161,6 @@ cc_library(
"iterator_util.h",
"map_util.h",
"overflow_util.h",
- "ptr_util.h",
"util.h",
],
visibility = ["//visibility:public"],
@@ -172,7 +171,8 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
- "//tensorflow/core:ptr_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -210,6 +210,7 @@ tf_cc_test(
":test",
":util",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -297,6 +298,7 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -315,6 +317,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -335,6 +338,7 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -405,8 +409,8 @@ cc_library(
deps = [
":array",
":types",
- ":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -489,6 +493,7 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -521,6 +526,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -576,10 +582,10 @@ cc_library(
deps = [
":shape_util",
":status_macros",
- ":util",
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -593,6 +599,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -642,6 +649,7 @@ cc_library(
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -660,6 +668,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h
index a17e81f448..340f94fab7 100644
--- a/tensorflow/compiler/xla/array2d.h
+++ b/tensorflow/compiler/xla/array2d.h
@@ -24,8 +24,8 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -101,7 +101,7 @@ class Array2D : public Array<T> {
template <typename NativeT = float>
std::unique_ptr<Array2D<NativeT>> MakeLinspaceArray2D(double from, double to,
int64 n1, int64 n2) {
- auto array = MakeUnique<Array2D<NativeT>>(n1, n2);
+ auto array = absl::make_unique<Array2D<NativeT>>(n1, n2);
int64 count = n1 * n2;
NativeT step =
static_cast<NativeT>((count > 1) ? (to - from) / (count - 1) : 0);
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index ad3fcee05b..6be44b1c39 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -71,12 +71,12 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -104,7 +104,6 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
@@ -117,6 +116,7 @@ cc_library(
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:support",
],
)
@@ -130,11 +130,11 @@ cc_library(
":xla_computation",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:compile_only_service",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:support",
],
)
@@ -159,6 +159,7 @@ cc_library(
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -186,6 +187,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
+ "@com_google_absl//absl/memory",
],
)
@@ -211,6 +213,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index d0ce5e8a6a..25608d6616 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -18,11 +18,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -89,7 +89,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
"TransferToServer request");
}
- return MakeUnique<GlobalData>(stub_, response.data());
+ return absl::make_unique<GlobalData>(stub_, response.data());
}
Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
@@ -248,7 +248,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
}
}
- return MakeUnique<GlobalData>(stub_, response.output());
+ return absl::make_unique<GlobalData>(stub_, response.output());
}
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
@@ -278,7 +278,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
std::vector<std::unique_ptr<GlobalData>> outputs;
for (size_t i = 0; i < computations.size(); ++i) {
outputs.push_back(
- MakeUnique<GlobalData>(stub_, response.responses(i).output()));
+ absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {
*computations[i].execution_profile = response.responses(i).profile();
}
@@ -340,7 +340,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
std::vector<std::unique_ptr<GlobalData>> handles;
for (auto& handle : response.element_handles()) {
- handles.push_back(MakeUnique<GlobalData>(stub_, handle));
+ handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
}
return std::move(handles);
}
@@ -369,7 +369,7 @@ StatusOr<ComputationStats> Client::GetComputationStats(
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
const XlaComputation& computation) {
TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
- return MakeUnique<ProgramShape>(result);
+ return absl::make_unique<ProgramShape>(result);
}
StatusOr<Shape> Client::GetShape(const GlobalData& data) {
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
index 803a9e4009..27b7fa7b29 100644
--- a/tensorflow/compiler/xla/client/client_library.cc
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default;
service_options.set_intra_op_parallelism_threads(
options.intra_op_parallelism_threads());
- auto instance = MakeUnique<LocalInstance>();
+ auto instance = absl::make_unique<LocalInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
LocalService::NewService(service_options));
- instance->client = MakeUnique<LocalClient>(instance->service.get());
+ instance->client = absl::make_unique<LocalClient>(instance->service.get());
LocalClient* cl = instance->client.get();
client_library.local_instances_.insert(
@@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
return it->second->client.get();
}
- auto instance = MakeUnique<CompileOnlyInstance>();
+ auto instance = absl::make_unique<CompileOnlyInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
CompileOnlyService::NewService(platform));
- instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+ instance->client =
+ absl::make_unique<CompileOnlyClient>(instance->service.get());
CompileOnlyClient* cl = instance->client.get();
client_library.compile_only_instances_.insert(
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index 5c9abad4c3..b6012a0352 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 0221de7672..e569610b85 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -207,7 +207,11 @@ XlaOp Lgamma(XlaOp input) {
XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x);
- XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y;
+ // If z = a + 0j, the analytic continuation of log reduces to taking the
+ // absolute value of the real part.
+ // Re(log(z)) = Re(log|z| + arg(z)j)
+ // = log|a|
+ XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y;
XlaOp result = Select(need_to_reflect, reflection, log_y);
return result;
}
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index cffb24e29b..1cd3e9b22f 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
@@ -257,9 +257,9 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(
computation, argument_layouts, updated_options));
- return WrapUnique(new LocalExecutable(std::move(executable),
- local_service_->mutable_backend(),
- updated_options));
+ return absl::WrapUnique(new LocalExecutable(std::move(executable),
+ local_service_->mutable_backend(),
+ updated_options));
}
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index b3b00e2fff..54fe87a7a8 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
@@ -469,8 +471,8 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
@@ -622,8 +624,8 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension));
@@ -749,8 +751,8 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
@@ -882,24 +884,28 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding) {
+ Padding padding, int64 feature_group_count) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
+ CreateDefaultConvDimensionNumbers(window_strides.size()),
+ feature_group_count);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count) {
return ConvGeneral(lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
+ CreateDefaultConvDimensionNumbers(window_strides.size()),
+ feature_group_count);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -926,7 +932,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
return ConvGeneral(lhs, rhs, window_strides,
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
- dimension_numbers);
+ dimension_numbers, feature_group_count);
});
}
@@ -934,9 +940,10 @@ XlaOp XlaBuilder::ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers);
+ dimension_numbers, feature_group_count);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -945,7 +952,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -964,12 +972,13 @@ XlaOp XlaBuilder::ConvGeneralDilated(
MakeWindow(window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation));
- TF_ASSIGN_OR_RETURN(
- *instr.mutable_shape(),
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(),
- dimension_numbers));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, instr.window(),
+ dimension_numbers, feature_group_count));
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
+ instr.set_feature_group_count(feature_group_count);
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
{lhs, rhs});
@@ -1073,6 +1082,23 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
"Replicated sharding is not yet supported for infeeds");
}
+ // Infeed takes a single token operand. Generate the token to pass to the
+ // infeed.
+ XlaOp token;
+ auto make_token = [&]() {
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
+ };
+ if (sharding()) {
+ // Arbitrarily assign token to device 0.
+ OpSharding sharding = sharding_builder::AssignDevice(0);
+ XlaScopedShardingAssignment scoped_sharding(this, sharding);
+ TF_ASSIGN_OR_RETURN(token, make_token());
+ } else {
+ TF_ASSIGN_OR_RETURN(token, make_token());
+ }
+
// The sharding is set by the client according to the data tuple shape.
// However, the shape of the infeed instruction is a tuple containing the
// data and a token. For tuple sharding type, the sharding must be changed
@@ -1088,11 +1114,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
sharding_builder::AssignDevice(0);
XlaScopedShardingAssignment scoped_sharding(this,
infeed_instruction_sharding);
- TF_ASSIGN_OR_RETURN(
- infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
+ TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
+ HloOpcode::kInfeed, {token}));
} else {
- TF_ASSIGN_OR_RETURN(
- infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {}));
+ TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
+ HloOpcode::kInfeed, {token}));
}
// The infeed instruction produces a tuple of the infed data and a token
@@ -1158,8 +1184,15 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
instr.set_outfeed_config(outfeed_config);
+ // Outfeed takes a token as its second operand. Generate the token to pass
+ // to the outfeed.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
+
TF_RETURN_IF_ERROR(
- AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand})
+ AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
.status());
// The outfeed instruction produces a token. However, existing users expect
@@ -1509,8 +1542,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
@@ -1600,27 +1633,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
});
}
-XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
- TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape,
- GetShape(gather_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
- ShapeInference::InferGatherShape(input_shape, gather_indices_shape,
- dimension_numbers, window_bounds));
+ ShapeInference::InferGatherShape(input_shape, start_indices_shape,
+ dimension_numbers, slice_sizes));
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
- for (int64 bound : window_bounds) {
- instr.add_gather_window_bounds(bound);
+ for (int64 bound : slice_sizes) {
+ instr.add_gather_slice_sizes(bound);
}
return AddInstruction(std::move(instr), HloOpcode::kGather,
- {input, gather_indices});
+ {input, start_indices});
});
}
@@ -1914,8 +1947,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
std::vector<const Shape*> slice_shape_ptrs;
- c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
@@ -2265,7 +2298,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
const string& computation_name) {
- auto sub_builder = MakeUnique<XlaBuilder>(computation_name);
+ auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
sub_builder->parent_builder_ = this;
sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
return sub_builder;
@@ -2538,32 +2571,38 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return lhs.builder()->Conv(lhs, rhs, window_strides, padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count) {
+ return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
+ feature_group_count);
}
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count) {
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding);
+ padding, feature_group_count);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
- padding, dimension_numbers);
+ padding, dimension_numbers,
+ feature_group_count);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
- dimension_numbers);
+ dimension_numbers, feature_group_count);
}
XlaOp ConvGeneralDilated(
@@ -2572,10 +2611,11 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation,
- dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
+ return lhs.builder()->ConvGeneralDilated(
+ lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
@@ -2868,11 +2908,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
mantissa_bits);
}
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return input.builder()->Gather(input, gather_indices, dimension_numbers,
- window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return input.builder()->Gather(input, start_indices, dimension_numbers,
+ slice_sizes);
}
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 9403d7ca8d..469d5048b2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -512,22 +512,24 @@ class XlaBuilder {
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -535,7 +537,8 @@ class XlaBuilder {
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -545,7 +548,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -873,9 +877,9 @@ class XlaBuilder {
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
- XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -1161,27 +1165,31 @@ class XlaBuilder {
const DotDimensionNumbers& dimension_numbers);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ Padding padding, int64 feature_group_count);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1320,9 +1328,9 @@ class XlaBuilder {
const XlaComputation& false_computation);
friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
- friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
@@ -1646,28 +1654,32 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -1677,7 +1689,8 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -2011,9 +2024,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc
index 3543d41fc2..22c9e83bb2 100644
--- a/tensorflow/compiler/xla/client/xla_computation.cc
+++ b/tensorflow/compiler/xla/client/xla_computation.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -32,7 +32,7 @@ StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
if (IsNull()) {
return InvalidArgument("Computation is invalid.");
}
- auto session = MakeUnique<HloSnapshot>();
+ auto session = absl::make_unique<HloSnapshot>();
*session->mutable_hlo()->mutable_hlo_module() = proto_;
return std::move(session);
}
diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc
index 7bc3189507..ec8b66df2d 100644
--- a/tensorflow/compiler/xla/iterator_util_test.cc
+++ b/tensorflow/compiler/xla/iterator_util_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <algorithm>
#include <list>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
@@ -27,7 +27,7 @@ namespace {
TEST(UnwrappingIteratorTest, Simple) {
std::vector<std::unique_ptr<int>> v;
for (int i = 0; i < 3; ++i) {
- v.push_back(MakeUnique<int>(i));
+ v.push_back(absl::make_unique<int>(i));
}
int i = 0;
for (auto iter = MakeUnwrappingIterator(v.begin());
@@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) {
TEST(UnwrappingIteratorTest, StdFind) {
std::list<std::unique_ptr<int>> l;
for (int i = 0; i < 3; ++i) {
- l.push_back(MakeUnique<int>(i));
+ l.push_back(absl::make_unique<int>(i));
}
EXPECT_EQ(l.begin()->get(),
*std::find(MakeUnwrappingIterator(l.begin()),
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 1bf8948ef6..5d27e4a46b 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -316,6 +316,13 @@ void AllocateFlags() {
bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
flag_values->xla_cpu_use_mkl_dnn(),
"Generate calls to MKL-DNN in the CPU backend."),
+ tensorflow::Flag(
+ "xla_gpu_crash_on_verification_failures",
+ bool_setter_for(
+ &DebugOptions::set_xla_gpu_crash_on_verification_failures),
+ flag_values->xla_gpu_crash_on_verification_failures(),
+ "Crashes the program on extra verification failures, e.g. cuDNN "
+ "cross checking failures"),
});
ParseFlagsFromEnv(*flag_objects);
}
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 36e472568e..d54f051a1a 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -134,7 +135,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
Literal::Literal(const Shape& shape, bool allocate_arrays)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(shape);
+ shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
@@ -175,7 +176,7 @@ Literal& Literal::operator=(Literal&& other) {
}
std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = MakeUnique<Literal>(shape);
+ auto literal = absl::make_unique<Literal>(shape);
literal->root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
@@ -289,7 +290,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = MakeUnique<Literal>(proto.shape());
+ auto literal = absl::make_unique<Literal>(proto.shape());
TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
@@ -479,7 +480,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
dest_piece.set_sparse_indices(src_piece.sparse_indices());
});
- src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
delete src_literal.root_piece_;
src_literal.root_piece_ = new LiteralBase::Piece();
src_literal.root_piece_->set_subshape(src_literal.shape_.get());
@@ -566,7 +567,7 @@ std::unique_ptr<Literal> LiteralBase::Relayout(
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = MakeUnique<Literal>(new_shape);
+ auto result = absl::make_unique<Literal>(new_shape);
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
@@ -602,7 +603,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+ std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
@@ -691,7 +692,7 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = MakeUnique<Literal>(permuted_shape);
+ auto new_literal = absl::make_unique<Literal>(permuted_shape);
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
ShapeUtil::ByteSizeOf(shape()));
std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
@@ -702,7 +703,7 @@ template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::SliceInternal(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> start_indices) const {
- auto result_literal = MakeUnique<Literal>(result_shape);
+ auto result_literal = absl::make_unique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
result_literal->EachCell<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
@@ -756,7 +757,7 @@ Literal LiteralBase::Clone() const {
}
std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = MakeUnique<Literal>(shape());
+ auto result = absl::make_unique<Literal>(shape());
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
@@ -1203,7 +1204,7 @@ template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
const LiteralBase& src_literal, const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+ auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
@@ -1249,7 +1250,7 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(
+ auto result_literal = absl::make_unique<Literal>(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
@@ -1396,7 +1397,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
elements.push_back(std::move(*new_element));
}
- auto converted = MakeUnique<Literal>();
+ auto converted = absl::make_unique<Literal>();
*converted = MutableLiteralBase::MoveIntoTuple(&elements);
return std::move(converted);
}
@@ -1956,7 +1957,7 @@ MutableLiteralBase::~MutableLiteralBase() {}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableBorrowingLiteral& literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1967,7 +1968,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
const MutableBorrowingLiteral& literal) {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1981,7 +1982,7 @@ MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableLiteralBase& literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1992,7 +1993,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal->shape());
+ shape_ = absl::make_unique<Shape>(literal->shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -2004,7 +2005,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
+ shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -2016,7 +2017,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
const Shape& shape)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(shape);
+ shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
CHECK(!ShapeUtil::IsTuple(*shape_));
@@ -2061,7 +2062,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_));
CHECK(LayoutUtil::HasLayout(*shape_));
@@ -2072,7 +2073,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
BorrowingLiteral::BorrowingLiteral(
tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(ShapeUtil::IsTuple(*shape_));
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 92c0f903cb..ed9de65299 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -25,13 +25,13 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -312,7 +312,7 @@ class LiteralBase {
// Note: It's an antipattern to use this method then immediately call
// MutableLiteralBase::Populate on the result (since that results in zero
// initialization, then reinitialization. Conside if a call to
- // MakeUnique<Literal>(shape), followed by the call to
+ // absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
@@ -1154,8 +1154,8 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal =
- MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ auto literal = absl::make_unique<Literal>(
+ ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal->shape());
if (elements == 0) {
return literal;
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 94993cc874..6883a6bbab 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -38,7 +38,8 @@ namespace {
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
-Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+Status CompareFloatsBitwiseEqual(
+ FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice<int64> multi_index) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
@@ -46,9 +47,10 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
if (ulhs != urhs) {
return InvalidArgument(
"floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a",
+ "was requested: %s=%g=%a vs %s=%g=%a at index %s",
StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
- StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
+ StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double,
+ LiteralUtil::MultiIndexAsString(multi_index).c_str());
}
return Status::OK();
}
@@ -57,39 +59,48 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
// bitwise helper above (this is the un-specialized fallback, to just use the
// default gunit implementation).
template <typename NativeT>
-Status CompareEqual(NativeT lhs, NativeT rhs) {
+Status CompareEqual(NativeT lhs, NativeT rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
if (lhs == rhs) {
return Status::OK();
}
- return InvalidArgument("Expected equality of these values:\n %s\n %s",
- StrCat(lhs).c_str(), StrCat(rhs).c_str());
+ return InvalidArgument(
+ "Expected equality of these values:\n %s\n %s\nat index %s",
+ StrCat(lhs).c_str(), StrCat(rhs).c_str(),
+ LiteralUtil::MultiIndexAsString(multi_index).c_str());
}
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
-Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
- return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
- return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
+Status CompareEqual<Eigen::half>(
+ Eigen::half lhs, Eigen::half rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<float>(float lhs, float rhs) {
- return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+Status CompareEqual<float>(float lhs, float rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<double>(double lhs, double rhs) {
- return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+Status CompareEqual<double>(double lhs, double rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
- auto res = CompareEqual<float>(lhs.real(), rhs.real());
+Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index);
if (!res.ok()) {
return res;
}
- return CompareEqual<float>(lhs.imag(), rhs.imag());
+ return CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
}
// A recursive function which iterates through every index of expected and
@@ -102,7 +113,7 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
- return CompareEqual<NativeT>(expected_value, actual_value);
+ return CompareEqual<NativeT>(expected_value, actual_value, multi_index);
}
Status result;
@@ -720,12 +731,10 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
return Status::OK();
}
- return AppendStatus(result,
- tensorflow::strings::Printf(
- "\nat index: %s\nexpected: %s\nactual: %s",
- LiteralUtil::MultiIndexAsString(multi_index).c_str(),
- ToStringTruncated(expected).c_str(),
- ToStringTruncated(actual).c_str()));
+ return AppendStatus(
+ result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s",
+ ToStringTruncated(expected).c_str(),
+ ToStringTruncated(actual).c_str()));
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index e8f919950f..c5d0c2c267 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -355,15 +356,15 @@ TEST_F(LiteralUtilTest, TokenEquality) {
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor =
- MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ auto colmajor = absl::make_unique<Literal>(
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
colmajor->Set<float>({0, 0}, 1.0);
colmajor->Set<float>({0, 1}, 2.0);
colmajor->Set<float>({1, 0}, 3.0);
colmajor->Set<float>({1, 1}, 4.0);
- auto rowmajor =
- MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ auto rowmajor = absl::make_unique<Literal>(
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
rowmajor->Set<float>({0, 0}, 1.0);
rowmajor->Set<float>({0, 1}, 2.0);
rowmajor->Set<float>({1, 0}, 3.0);
@@ -1089,7 +1090,7 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = MakeUnique<Literal>(shape);
+ auto literal = absl::make_unique<Literal>(shape);
auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
@@ -1131,7 +1132,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = MakeUnique<Literal>(shape);
+ auto literal = absl::make_unique<Literal>(shape);
auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 5d33df7d40..d4c7b76b28 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -57,7 +58,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@@ -102,7 +103,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
- return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
+ return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
@@ -279,7 +280,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = MakeUnique<Literal>(
+ auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
literal->PopulateR1(values);
return literal;
@@ -287,7 +288,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
tensorflow::StringPiece value) {
- auto literal = MakeUnique<Literal>(
+ auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
literal->Set<uint8>({i}, value[i]);
@@ -312,7 +313,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
- auto new_literal = MakeUnique<Literal>(
+ auto new_literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
@@ -436,7 +437,8 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ auto literal =
+ absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
@@ -449,7 +451,8 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ auto literal =
+ absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
@@ -463,7 +466,8 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
for (const auto& element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ auto literal =
+ absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e3737a9d00..1109021ea8 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -27,6 +27,7 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -327,7 +327,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
+ auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
literal->Set({}, value);
return literal;
@@ -336,7 +336,7 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
- auto literal = MakeUnique<Literal>(
+ auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
literal->PopulateR1(values);
@@ -347,7 +347,7 @@ template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
@@ -433,9 +433,10 @@ template <typename NativeT>
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
- indices.max_indices()));
+ auto literal =
+ absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+ indices.max_indices()));
literal->PopulateSparse(indices, values, sort);
return literal;
}
@@ -451,7 +452,7 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
literal->PopulateFromArray(values);
@@ -571,8 +572,9 @@ template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ auto literal =
+ absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
literal->PopulateWithValue(value);
return literal;
}
@@ -584,7 +586,7 @@ LiteralUtil::CreateRandomLiteral(
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = MakeUnique<Literal>(shape);
+ auto literal = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
return generator(indexes);
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 6b7fd10d63..55c4a80e29 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -57,7 +57,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
PrimitiveType_Name(shape.element_type()).c_str());
}
- auto result = MakeUnique<Literal>(literal_shape);
+ auto result = absl::make_unique<Literal>(literal_shape);
result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index c8f2d65c22..a91336c3ac 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -59,6 +59,7 @@ cc_library(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 8246f76d34..c133a20419 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -14,10 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -575,6 +575,16 @@ StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
return builder_.IsConstant(operand.op());
}
+LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
+ return xla::Sort(operand.op(), tensorflow::gtl::nullopt, dimension);
+}
+
+LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
+ const LocalOp& values,
+ int64 dimension) {
+ return xla::Sort(keys.op(), values.op(), dimension);
+}
+
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
const LocalOp& operand) {
TF_ASSIGN_OR_RETURN(XlaComputation computation,
@@ -640,7 +650,6 @@ _FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
-_FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index a568c24c63..5f9078ab84 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -301,6 +301,11 @@ class LocalComputationBuilder {
StatusOr<bool> IsConstant(const LocalOp& operand);
+ LocalOp Sort(const LocalOp& operand, int64 dimension);
+
+ LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
+ int64 dimension);
+
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \
@@ -357,7 +362,6 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
- _FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 5d5a955bfe..fa5d75908f 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -1011,6 +1011,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::LocalComputationBuilder::SortKeyVal;
%unignore xla::swig::LocalComputationBuilder::Sqrt;
%unignore xla::swig::LocalComputationBuilder::Rsqrt;
%unignore xla::swig::LocalComputationBuilder::Square;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a2c6fc344d..fa4366ff07 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -105,7 +105,6 @@ _UNARY_OPS = [
'Square',
'Reciprocal',
'Neg',
- 'Sort',
'Erf',
'Erfc',
'ErfInv',
@@ -1218,6 +1217,14 @@ class ComputationBuilder(object):
lhs_dilation, rhs_dilation,
dimension_numbers)
+ def Sort(self, operand, dimension=-1):
+ """Enqueues a sort operation onto the computation."""
+ return self._client.Sort(operand, dimension)
+
+ def SortKeyVal(self, keys, values, dimension=-1):
+ """Enqueues a key-value sort operation onto the computation."""
+ return self._client.SortKeyVal(keys, values, dimension)
+
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index a803520876..3de7ee2bc8 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <array>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -43,7 +44,7 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
int m = lhs.height();
int n = rhs.width();
int k = lhs.width();
- auto result = MakeUnique<Array2D<T>>(m, n);
+ auto result = absl::make_unique<Array2D<T>>(m, n);
// Because Eigen is a header-oriented library, make sure that the Eigen code
// is the same as the code used by the CPU backend (otherwise the linker will
// randomly pick *some* definition).
@@ -77,7 +78,8 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
const Array2D<float>& input) {
- auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
+ auto result =
+ absl::make_unique<Array2D<double>>(input.height(), input.width());
for (int64 rowno = 0; rowno < input.height(); ++rowno) {
for (int64 colno = 0; colno < input.height(); ++colno) {
(*result)(rowno, colno) = input(rowno, colno);
@@ -126,8 +128,8 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
{rhs_dilation, 1}, dnums2d);
- auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(),
- convr4->height());
+ auto convr3 = absl::make_unique<Array3D<float>>(
+ convr4->planes(), convr4->depth(), convr4->height());
convr4->Each(
[&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
CHECK_EQ(indices[3], 0);
@@ -201,7 +203,7 @@ ReferenceUtil::ReduceWindow1DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<std::vector<float>>(window_counts[0]);
+ auto result = absl::make_unique<std::vector<float>>(window_counts[0]);
// Do a full 1D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
@@ -247,7 +249,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
+ auto result =
+ absl::make_unique<Array2D<float>>(window_counts[0], window_counts[1]);
// Do a full 2D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
@@ -296,8 +299,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
WindowCount(dim_lengths[i], window[i], stride[i], padding);
pad_low[i] = padding_both[i].first;
}
- auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1],
- window_counts[2]);
+ auto result = absl::make_unique<Array3D<float>>(
+ window_counts[0], window_counts[1], window_counts[2]);
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
@@ -358,8 +361,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
- window_counts[2], window_counts[3]);
+ auto result = absl::make_unique<Array4D<float>>(
+ window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
// Do a full 4D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
@@ -426,8 +429,8 @@ ReferenceUtil::SelectAndScatter4DGePlus(
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
- auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
- operand.n3(), operand.n4());
+ auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
+ operand.n3(), operand.n4());
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -583,10 +586,10 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
auto result =
- MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
+ result_literal->shape().dimensions(1),
+ result_literal->shape().dimensions(2),
+ result_literal->shape().dimensions(3));
result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
*value = result_literal->Get<float>(indices);
@@ -601,7 +604,7 @@ ReferenceUtil::ReduceToColArray2D(
const std::function<float(float, float)>& reduce_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<std::vector<float>>();
+ auto result = absl::make_unique<std::vector<float>>();
for (int64 i = 0; i < rows; ++i) {
float acc = init;
for (int64 j = 0; j < cols; ++j) {
@@ -618,7 +621,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, float)>& reduce_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<std::vector<float>>();
+ auto result = absl::make_unique<std::vector<float>>();
for (int64 i = 0; i < cols; ++i) {
float acc = init;
for (int64 j = 0; j < rows; ++j) {
@@ -674,8 +677,8 @@ ReferenceUtil::ReduceToRowArray2D(
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
const std::vector<float>& array, const std::vector<int64>& bounds,
int64 broadcast_from_dim) {
- auto result =
- MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
+ auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
+ bounds[2], bounds[3]);
for (int64 i = 0; i < result->n1(); ++i) {
for (int64 j = 0; j < result->n2(); ++j) {
for (int64 k = 0; k < result->n3(); ++k) {
@@ -710,7 +713,7 @@ ReferenceUtil::ReduceToRowArray2D(
CHECK_EQ(dims.size(), 1);
int64 rows = dims[0] == 0 ? array.n2() : array.n1();
int64 cols = dims[0] == 2 ? array.n2() : array.n3();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
result->Fill(init);
for (int i0 = 0; i0 < array.n1(); ++i0) {
for (int i1 = 0; i1 < array.n2(); ++i1) {
@@ -730,7 +733,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float)>& map_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(matrix(i, j));
@@ -746,7 +749,7 @@ ReferenceUtil::ReduceToRowArray2D(
CHECK_EQ(lhs.width(), rhs.width());
int64 rows = lhs.height();
int64 cols = rhs.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
@@ -760,7 +763,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, int64, int64)>& map_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(matrix(i, j), i, j);
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 8fa6961d19..88f853a359 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -22,11 +22,11 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -42,7 +42,8 @@ class ReferenceUtil {
template <typename T>
static std::unique_ptr<Array2D<T>> TransposeArray2D(
const Array2D<T>& operand) {
- auto result = MakeUnique<Array2D<T>>(operand.width(), operand.height());
+ auto result =
+ absl::make_unique<Array2D<T>>(operand.width(), operand.height());
for (int64 w = 0; w < operand.width(); ++w) {
for (int64 h = 0; h < operand.height(); ++h) {
(*result)(w, h) = operand(h, w);
@@ -242,7 +243,7 @@ class ReferenceUtil {
const Array2D<T>& rhs,
int concatenate_dimension) {
CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
- auto result = MakeUnique<Array2D<T>>(
+ auto result = absl::make_unique<Array2D<T>>(
concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
@@ -276,7 +277,8 @@ class ReferenceUtil {
out_dims[i] = lhs_dims[i] + rhs_dims[i];
}
}
- auto result = MakeUnique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
+ auto result =
+ absl::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -310,8 +312,8 @@ class ReferenceUtil {
out_dims[i] = lhs_dims[i] + rhs_dims[i];
}
}
- auto result = MakeUnique<Array4D<T>>(out_dims[0], out_dims[1], out_dims[2],
- out_dims[3]);
+ auto result = absl::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
+ out_dims[2], out_dims[3]);
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -355,9 +357,9 @@ class ReferenceUtil {
CHECK_LE(limits[1], input.n2());
CHECK_GE(strides[0], 1);
CHECK_GE(strides[1], 1);
- auto result =
- MakeUnique<Array2D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]));
+ auto result = absl::make_unique<Array2D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
(*result)(i0, i1) =
@@ -381,10 +383,10 @@ class ReferenceUtil {
CHECK_GE(strides[0], 1);
CHECK_GE(strides[1], 1);
CHECK_GE(strides[2], 1);
- auto result =
- MakeUnique<Array3D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]),
- CeilOfRatio(limits[2] - starts[2], strides[2]));
+ auto result = absl::make_unique<Array3D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
@@ -415,11 +417,11 @@ class ReferenceUtil {
CHECK_GE(strides[1], 1);
CHECK_GE(strides[2], 1);
CHECK_GE(strides[3], 1);
- auto result =
- MakeUnique<Array4D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]),
- CeilOfRatio(limits[2] - starts[2], strides[2]),
- CeilOfRatio(limits[3] - starts[3], strides[3]));
+ auto result = absl::make_unique<Array4D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]),
+ CeilOfRatio(limits[3] - starts[3], strides[3]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -460,8 +462,8 @@ class ReferenceUtil {
template <typename F>
static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
const Array4D<float>& input, F&& map_function) {
- auto result = MakeUnique<Array4D<float>>(input.planes(), input.depth(),
- input.height(), input.width());
+ auto result = absl::make_unique<Array4D<float>>(
+ input.planes(), input.depth(), input.height(), input.width());
for (int64 plane = 0; plane < input.planes(); ++plane) {
for (int64 depth = 0; depth < input.depth(); ++depth) {
for (int64 height = 0; height < input.height(); ++height) {
@@ -495,8 +497,8 @@ class ReferenceUtil {
template <typename F>
static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
- auto result = MakeUnique<Array4D<float>>(lhs.planes(), lhs.depth(),
- lhs.height(), lhs.width());
+ auto result = absl::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
+ lhs.height(), lhs.width());
for (int64 plane = 0; plane < lhs.planes(); ++plane) {
for (int64 depth = 0; depth < lhs.depth(); ++depth) {
for (int64 height = 0; height < lhs.height(); ++height) {
@@ -530,7 +532,7 @@ class ReferenceUtil {
int64 out1 =
in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
- auto result = MakeUnique<Array2D<NativeT>>(out0, out1);
+ auto result = absl::make_unique<Array2D<NativeT>>(out0, out1);
result->Fill(pad);
int64 o0 = low_padding0;
for (int64 i0 = 0; i0 < in0; ++i0) {
@@ -669,7 +671,7 @@ class ReferenceUtil {
static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
AssertSameSize2D(array1, arrays...);
- auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
+ auto result = absl::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
for (int64 i = 0; i < array1.n1(); ++i) {
for (int64 j = 0; j < array1.n2(); ++j) {
(*result)(i, j) = f(array1(i, j), arrays(i, j)...);
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 8091bed499..3ec0192148 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -18,12 +18,12 @@ limitations under the License.
#include <cmath>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -36,7 +36,7 @@ namespace {
class ReferenceUtilTest : public ::testing::Test {
protected:
ReferenceUtilTest() {
- matrix_ = MakeUnique<Array2D<float>>(rows_, cols_);
+ matrix_ = absl::make_unique<Array2D<float>>(rows_, cols_);
// [1.f 2.f 3.f]
// [4.f 5.f 6.f]
for (int64 i = 0; i < rows_; ++i) {
@@ -112,8 +112,8 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
}
TEST_F(ReferenceUtilTest, MapArray4D) {
- auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
- /*height=*/4, /*width=*/5);
+ auto input = absl::make_unique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
@@ -126,8 +126,8 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
}
TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
- auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
- /*height=*/4, /*width=*/5);
+ auto input = absl::make_unique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
input->FillWithMultiples(1.0f);
auto subtract_index = [](float value, int64 plane, int64 depth, int64 height,
int64 width) {
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7331d2b54c..7fdffe85c0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -175,6 +175,7 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -237,6 +238,8 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -263,6 +266,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -311,6 +315,8 @@ cc_library(
"//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -449,6 +455,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -517,6 +524,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -574,6 +582,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -615,6 +624,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -647,6 +657,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -719,6 +730,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -736,6 +748,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:ptr_util",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -766,6 +779,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
],
)
@@ -813,6 +827,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -831,6 +846,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -847,6 +863,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -864,6 +881,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -923,6 +941,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -950,6 +969,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -977,6 +997,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1031,6 +1052,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1049,6 +1071,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1065,6 +1088,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1074,6 +1098,7 @@ cc_library(
hdrs = ["hlo_module_group_util.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_group_metadata",
":hlo_reachability",
"//tensorflow/compiler/xla:status",
@@ -1082,6 +1107,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1142,6 +1168,7 @@ cc_library(
":hlo_pass",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1181,6 +1208,8 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1198,6 +1227,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1231,6 +1261,7 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1245,6 +1276,7 @@ cc_library(
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1267,6 +1299,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1289,6 +1322,8 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1312,6 +1347,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1323,8 +1359,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":hlo_pass",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1400,6 +1435,40 @@ tf_cc_test(
)
cc_library(
+ name = "convolution_feature_group_converter",
+ srcs = ["convolution_feature_group_converter.cc"],
+ hdrs = ["convolution_feature_group_converter.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+tf_cc_test(
+ name = "convolution_feature_group_converter_test",
+ size = "small",
+ srcs = ["convolution_feature_group_converter_test.cc"],
+ deps = [
+ ":convolution_feature_group_converter",
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ ],
+)
+
+cc_library(
name = "while_loop_analysis",
srcs = ["while_loop_analysis.cc"],
hdrs = ["while_loop_analysis.h"],
@@ -1549,6 +1618,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1569,6 +1639,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1602,6 +1673,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -1621,6 +1693,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform computation placer registration
)
@@ -1711,6 +1784,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1756,6 +1831,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1772,6 +1848,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1831,6 +1908,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1849,6 +1927,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1890,6 +1969,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1983,6 +2063,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -1995,7 +2076,6 @@ cc_library(
":hlo_dataflow_analysis",
":logical_buffer",
":logical_buffer_analysis",
- "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -2003,6 +2083,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2053,6 +2134,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2143,6 +2225,7 @@ cc_library(
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2225,6 +2308,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2306,6 +2390,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2343,6 +2428,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2359,6 +2445,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2390,6 +2477,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2404,6 +2492,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2464,6 +2553,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2532,6 +2622,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
"@llvm//:core",
"@llvm//:transform_utils",
],
@@ -2563,10 +2654,10 @@ cc_library(
":computation_layout",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2747,9 +2838,9 @@ cc_library(
hdrs = ["stream_pool.h"],
deps = [
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -2847,6 +2938,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -2894,6 +2986,7 @@ cc_library(
":tuple_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2907,6 +3000,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2922,6 +3016,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2949,6 +3044,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -3003,6 +3099,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -3036,6 +3133,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index f7812d9661..1d26e30651 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -540,7 +542,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -1752,8 +1754,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
}
auto is_unstrided_slice = [](const HloInstruction* hlo) {
- return c_all_of(hlo->slice_strides(),
- [](int64 stride) { return stride == 1; });
+ return absl::c_all_of(hlo->slice_strides(),
+ [](int64 stride) { return stride == 1; });
};
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) {
@@ -1930,7 +1932,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// This should make fusion easier or use less memory bandwidth in the unfused
// case.
if (arg->opcode() == HloOpcode::kConcatenate &&
- c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) {
+ absl::c_linear_search(reduce->dimensions(),
+ arg->concatenate_dimension())) {
HloInstruction* old_reduce = nullptr;
for (HloInstruction* operand : arg->operands()) {
HloInstruction* new_reduce = computation_->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 5837391d75..427069af5f 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -2037,7 +2037,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
- auto build_and_simplify = [&options]() -> string {
+ auto build_and_simplify = [&]() -> string {
HloComputation::Builder b(TestName());
Window window;
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 51ebc4763b..d0806d24a2 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -91,8 +91,9 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
// If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
// into a regular ShapedBuffer, which is stored in
// handle_to_shaped_buffers_.
- handle_to_shaped_buffers_[handle].emplace_back(MakeUnique<ShapedBuffer>(
- ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
+ handle_to_shaped_buffers_[handle].emplace_back(
+ absl::make_unique<ShapedBuffer>(
+ ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
}
GlobalDataHandle result;
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index d12be3e007..841d0fa85b 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -127,8 +128,8 @@ Backend::Backend(
}
}
// Create a memory allocator for the valid stream executors.
- memory_allocator_ =
- MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
+ memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>(
+ platform, stream_executors);
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index 2099916509..b226e7ecb0 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -84,10 +85,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> dot_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
- [](HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kDot;
- });
+ absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
+ [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kDot;
+ });
}
for (HloInstruction* dot_instr : dot_instrs) {
TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one,
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index a725351462..f62ab12319 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index cfd26fc778..cc15c7122f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,8 +22,8 @@ limitations under the License.
#include <ostream>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -1100,8 +1100,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
+ HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
assignment->module(), module_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
@@ -1130,11 +1130,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
- *computation, *instruction_sequence,
- assignment->points_to_analysis(),
- assignment->buffer_size_, options));
+ HeapSimulator::Run(
+ absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
+ *computation, *instruction_sequence,
+ assignment->points_to_analysis(), assignment->buffer_size_,
+ options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@@ -1646,7 +1647,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
XLA_VLOG_LINES(3, liveness->ToString());
XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
- // Can't use MakeUnique because BufferAssignment constructor is private.
+ // Can't use absl::make_unique because BufferAssignment constructor is
+ // private.
std::unique_ptr<BufferAssignment> assignment(
new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
std::move(color_alignment)));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index eccb146a0d..52abda16c4 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -87,7 +87,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -98,7 +98,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
HloModule* module, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -109,7 +109,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -127,7 +127,8 @@ class BufferAssignmentTest : public HloTestBase {
instruction_sequence.end());
return BufferAssigner::Run(
module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ absl::make_unique<SequentialHloOrdering>(module,
+ module_sequence),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -1769,7 +1770,8 @@ class WhileBufferAssignmentTest : public HloTestBase {
auto sequence =
ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
- module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
+ module,
+ absl::make_unique<SequentialHloOrdering>(module, sequence),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -2083,7 +2085,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto assignment,
BufferAssigner::Run(
module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+ absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
@@ -2340,7 +2342,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto assignment =
BufferAssigner::Run(
module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+ absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true)
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 4a927b5767..3ffb7de65f 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -119,8 +119,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
@@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
SequentialHloOrdering::HloModuleSequence sequence;
sequence.insert({entry, {param0, negate, param1, exp, add}});
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), sequence))
+ .ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -215,8 +215,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -249,8 +249,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@@ -293,10 +293,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
SequentialHloOrdering::HloModuleSequence module_sequence;
std::vector<const HloInstruction*> order = {param, negate, exp, add};
module_sequence.emplace(computation, order);
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence))
+ .ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -342,10 +342,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
std::vector<const HloInstruction*> order = {param, add, recv,
recv_done, send, send_done};
module_sequence.emplace(computation, order);
- auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence))
+ .ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Check the root instruction (add) buffer interferes with the recv buffer.
@@ -376,8 +376,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// All buffers should be live out except the param
@@ -412,8 +412,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Buffers in different computations should always interfere.
@@ -453,8 +453,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Only the element buffers of the tuple constant which are pointed to by
@@ -518,8 +518,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -580,8 +580,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -668,10 +668,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
@@ -780,10 +780,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 985ff30e80..d6efef5f12 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <queue>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() {
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
- // Constructor for CallGraph is private so MakeUnique can't be used.
- auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
+ // Constructor for CallGraph is private so absl::make_unique can't be used.
+ auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
VLOG(2) << "Building call graph for:";
XLA_VLOG_LINES(2, module->ToString());
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index ff968bca29..e75f6f146d 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index 13008efed1..9c9e373821 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/channel_tracker.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 187ce568cb..afbbea35b8 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
"computation_count=%d",
proto.replica_count(), proto.computation_count());
}
- auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
- proto.computation_count());
+ auto assignment = absl::make_unique<DeviceAssignment>(
+ proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
@@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() {
} // namespace xla
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
- return xla::MakeUnique<xla::ComputationPlacer>();
+ return absl::make_unique<xla::ComputationPlacer>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
new file mode 100644
index 0000000000..8affa08b65
--- /dev/null
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -0,0 +1,248 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
+
+#include <memory>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace {
+
+// ConvolutionVisitor traverses the HLO computation and rewrites Convolution
+// operations with feature_group_count > 1 into convolutions with
+// feature_group_count = 1.
+class ConvolutionVisitor : public DfsHloVisitorWithDefault {
+ public:
+ // Default visitor action is to do nothing and return OK.
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
+ return Status::OK();
+ }
+
+ Status HandleConvolution(HloInstruction* convolution) override;
+
+ // Runs the visitor on a computation.
+ static bool Run(HloComputation* computation);
+
+ // Returns whether any convolution ops were rewritten.
+ const bool changed() const { return changed_; }
+
+ ~ConvolutionVisitor() override = default;
+
+ private:
+ explicit ConvolutionVisitor(HloComputation* computation)
+ : computation_(computation) {}
+
+ // Current HloComputation instance the ConvolutionVisitor is traversing.
+ HloComputation* computation_;
+
+ // Whether rewrite has occurred.
+ bool changed_ = false;
+};
+
+bool ConvolutionVisitor::Run(HloComputation* computation) {
+ ConvolutionVisitor visitor(computation);
+ TF_CHECK_OK(computation->Accept(&visitor));
+ return visitor.changed_;
+}
+
+Shape ExpandedFilterShape(const Shape& shape, int64 group_count,
+ int64 input_feature_dim) {
+ int64 num_dims = shape.dimensions_size();
+ CHECK_GE(num_dims, 2);
+ Shape expanded_shape = shape;
+ expanded_shape.set_dimensions(
+ input_feature_dim, shape.dimensions(input_feature_dim) * group_count);
+ return expanded_shape;
+}
+
+// Returns a vector with 'group_count' many groups, where the i-th group
+// consists of 'group_size' times the value i.
+std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) {
+ std::vector<int32> values;
+ for (int i = 0; i < group_count; ++i) {
+ for (int j = 0; j < group_size; ++j) {
+ values.push_back(i);
+ }
+ }
+ return values;
+}
+
+// Create a mask for grouped convolution that will make a normal convolution
+// produce the same results as a grouped convolution. For a [2, 1, 6]
+// filter this returns a [2, 3, 6] mask
+// 1 1 0 0 0 0
+// 0 0 1 1 0 0
+// 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0
+// 0 0 1 1 0 0
+// 0 0 0 0 1 1
+//
+// The first step is to create a rank 1 constant:
+// 0 1 2
+//
+// This is broadcasted to
+// 0 0 0 0 0 0
+// 1 1 1 1 1 1
+// 2 2 2 2 2 2
+//
+// 0 0 0 0 0 0
+// 1 1 1 1 1 1
+// 2 2 2 2 2 2
+//
+// Then we create another rank 1 constant
+// 0 0 1 1 2 2
+//
+// This is broadcasted to
+// 0 0 1 1 2 2
+// 0 0 1 1 2 2
+// 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2
+// 0 0 1 1 2 2
+// 0 0 1 1 2 2
+//
+// Finally we use the Eq op of these two broadcasted constants and get the
+// desired mask.
+HloInstruction* GetExpandedFilterMask(
+ const Shape& filter_shape, int64 input_feature_dim,
+ int64 output_feature_dim, int64 group_count,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ add_instruction) {
+ Shape expanded_filter_shape =
+ ExpandedFilterShape(filter_shape, group_count, input_feature_dim);
+ Shape mask_shape = ShapeUtil::MakeShape(
+ S32, AsInt64Slice(expanded_filter_shape.dimensions()));
+ int64 output_feature = filter_shape.dimensions(output_feature_dim);
+ int64 group_size = filter_shape.dimensions(input_feature_dim);
+
+ // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
+ // that will be broadcasted into perpendicular dimensions and compared.
+ const std::vector<int32> input_feature_filter_mask =
+ GetMaskIds(group_size, group_count);
+ const std::vector<int32> output_feature_filter_mask =
+ GetMaskIds(output_feature / group_count, group_count);
+
+ auto mask1 = add_instruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
+ auto broadcasted_mask1 = add_instruction(
+ HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim}));
+ auto mask2 = add_instruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
+ auto broadcasted_mask2 = add_instruction(
+ HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim}));
+
+ // Compare the broadcasted output feature linspace to the input feature
+ // linspace to create a diagonal predicate.
+ Shape predicate_shape = ShapeUtil::MakeShape(
+ PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
+ return add_instruction(HloInstruction::CreateBinary(
+ predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2));
+}
+
+Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
+ int64 group_count = convolution->feature_group_count();
+ if (group_count == 1) {
+ return Status::OK();
+ }
+ auto filter = convolution->mutable_operand(1);
+ changed_ = true;
+ auto add = [&](std::unique_ptr<HloInstruction> inst) {
+ return computation_->AddInstruction(std::move(inst));
+ };
+
+ auto dim_numbers = convolution->convolution_dimension_numbers();
+ int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension();
+ int64 group_size = filter->shape().dimensions(input_feature_dim);
+ int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension();
+ auto expanded_filter_shape =
+ ExpandedFilterShape(filter->shape(), group_count, input_feature_dim);
+ HloInstruction* filter_mask = GetExpandedFilterMask(
+ filter->shape(), input_feature_dim, output_feature_dim, group_count, add);
+ HloInstruction* expanded_filter;
+ // We want to repeat 'filter' in the 'input_feature_dim' dimension
+ // 'group_count' times.
+ if (group_size == 1) {
+ Shape reshaped_filter_shape =
+ ShapeUtil::DeleteDimension(input_feature_dim, filter->shape());
+ auto reshaped_filter =
+ add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
+ std::vector<int64> broadcast_dims;
+ for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
+ if (i == input_feature_dim) {
+ continue;
+ }
+ broadcast_dims.push_back(i);
+ }
+ expanded_filter = add(HloInstruction::CreateBroadcast(
+ expanded_filter_shape, reshaped_filter, broadcast_dims));
+ } else {
+ // We could possibly also use reshape, broadcast, reshape instead of concat
+ // here, but it would require more complex code, and for depthwise
+ // convolution we would never end up in this branch.
+ std::vector<HloInstruction*> concat_operands(group_count, filter);
+ expanded_filter = add(HloInstruction::CreateConcatenate(
+ expanded_filter_shape, concat_operands, input_feature_dim));
+ }
+ auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
+ LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero_filter =
+ add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
+ auto new_filter = add(
+ HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect,
+ filter_mask, expanded_filter, zero_filter));
+ auto new_convolution = HloInstruction::CreateConvolve(
+ convolution->shape(), convolution->mutable_operand(0), new_filter,
+ convolution->window(), dim_numbers, /*feature_group_count=*/1);
+ TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+ convolution, std::move(new_convolution)));
+ return Status::OK();
+}
+
+} // namespace
+
+StatusOr<bool> ConvolutionFeatureGroupConverter::Run(HloModule* module) {
+ XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" +
+ module->ToString());
+ bool changed = false;
+ for (auto* comp : module->MakeNonfusionComputations()) {
+ if (ConvolutionVisitor::Run(comp)) {
+ changed = true;
+ }
+ }
+ XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" +
+ module->ToString());
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
new file mode 100644
index 0000000000..f213cc8709
--- /dev/null
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// A pass which rewrites convolutions with feature_group_count > 1 into
+// convolutions with feature_group_count = 1.
+class ConvolutionFeatureGroupConverter : public HloPassInterface {
+ public:
+ ConvolutionFeatureGroupConverter() {}
+
+ tensorflow::StringPiece name() const override {
+ return "convolution-feature-group-converter";
+ }
+
+ // Run convolution rewriting on the given computation. Returns whether the
+ // computation was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc
new file mode 100644
index 0000000000..28373ebf63
--- /dev/null
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace {
+
+using ConvolutionFeatureGroupConverterTest = HloTestBase;
+namespace op = testing::opcode_matchers;
+
+TEST_F(ConvolutionFeatureGroupConverterTest,
+ ConvertFeatureGroupCountEqualToInputFeatureDim) {
+ string hlo_string = R"(HloModule Convolve1D1Window_0_module
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2,2] {
+ %input = f32[1,2,2]{2,1,0} parameter(0)
+ %copy = f32[1,2,2]{2,0,1} copy(f32[1,2,2]{2,1,0} %input)
+ %filter = f32[1,1,2]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_string));
+
+ auto computation = module->entry_computation();
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
+ ConvolutionFeatureGroupConverter converter;
+ ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ // Make sure the convolution is converted to one with feature_group_count = 1.
+ EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
+ EXPECT_EQ(root->feature_group_count(), 1);
+ // Verify that the filter operand has been replaced.
+ EXPECT_THAT(root->operand(1),
+ op::Select(op::Eq(op::Broadcast(op::Constant()),
+ op::Broadcast(op::Constant())),
+ op::Broadcast(op::Reshape(op::Parameter())),
+ op::Broadcast(op::Constant())));
+}
+
+TEST_F(ConvolutionFeatureGroupConverterTest,
+ ConvertFeatureGroupCountDivisorOfInputFeatureDim) {
+ string hlo_string = R"(HloModule Convolve1D1Window_0_module
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2,2] {
+ %input = f32[1,2,4]{2,1,0} parameter(0)
+ %copy = f32[1,2,4]{2,0,1} copy(f32[1,2,4]{2,1,0} %input)
+ %filter = f32[1,2,2]{2,1,0} parameter(1)
+ ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2
+})";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_string));
+
+ auto computation = module->entry_computation();
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
+ ConvolutionFeatureGroupConverter converter;
+ ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ // Make sure the convolution is converted to one with feature_group_count = 1.
+ EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
+ EXPECT_EQ(root->feature_group_count(), 1);
+ // Verify that the filter operand has been replaced.
+ EXPECT_THAT(root->operand(1),
+ op::Select(op::Eq(op::Broadcast(op::Constant()),
+ op::Broadcast(op::Constant())),
+ // We expect to see Concatenate here instead of
+ // Broadcast, because feature_group_count < input
+ // feature dimension.
+ op::Concatenate(op::Parameter(), op::Parameter()),
+ op::Broadcast(op::Constant())));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 84779c60b0..850948b54b 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -85,6 +86,7 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
+ "@com_google_absl//absl/memory",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
@@ -101,6 +103,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
+ "//tensorflow/compiler/xla/service:convolution_feature_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
@@ -177,6 +180,7 @@ cc_library(
":runtime_single_threaded_conv2d",
":runtime_single_threaded_fft",
":runtime_single_threaded_matmul",
+ "@com_google_absl//absl/memory",
"@llvm//:execution_engine",
"@llvm//:core",
"@llvm//:mc", # fixdeps: keep
@@ -417,6 +421,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
"@llvm//:analysis",
"@llvm//:core",
"@llvm//:ipo",
@@ -633,6 +638,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -809,6 +815,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
+ "@com_google_absl//absl/memory",
],
)
@@ -892,6 +899,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
"@llvm//:core",
"@llvm//:support",
],
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 128eea4828..73b03440cb 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses(
llvm::legacy::PassManagerBase* passes) const {
llvm::Triple target_triple(target_machine_->getTargetTriple());
auto target_library_info_impl =
- MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
+ absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
target_library_info_impl->addVectorizableFunctions(
VectorFunctionsForTargetLibraryInfoImpl());
passes->add(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 35154af048..5116f926f5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -26,6 +26,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
+#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
@@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
@@ -258,6 +259,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<CallInliner>();
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
+ pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
{
auto& pass =
@@ -276,7 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
- pipeline.AddPass<ZeroSizedHloElimination>();
+ pass.AddPass<ZeroSizedHloElimination>();
pass.AddPass<WhileLoopInvariantCodeMotion>();
pass.AddPass<TupleSimplifier>();
@@ -451,7 +453,7 @@ Status CreateHloProfilingArtifacts(
computation_to_profile_idx,
std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
- *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module);
+ *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
const HloComputation& entry_computation = *module.entry_computation();
TF_ASSIGN_OR_RETURN(
@@ -518,11 +520,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
&pre_optimization_ir_hook, &post_optimization_ir_hook));
// Compile must be thread-safe so create a new LLVM context for the module.
- auto llvm_context = xla::MakeUnique<llvm::LLVMContext>();
+ auto llvm_context = absl::make_unique<llvm::LLVMContext>();
auto llvm_module =
- xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context);
+ absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
- auto jit = xla::MakeUnique<SimpleOrcJIT>(
+ auto jit = absl::make_unique<SimpleOrcJIT>(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
@@ -564,12 +566,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(
+ module.get(), module_sequence),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -714,7 +716,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name());
llvm::StringRef features = llvm_ir::AsStringRef(options.features());
llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
- std::unique_ptr<llvm::TargetMachine> target_machine = WrapUnique(
+ std::unique_ptr<llvm::TargetMachine> target_machine = absl::WrapUnique(
target->createTargetMachine(triple.getTriple(), cpu_name, features,
CompilerTargetOptions(modules[0]->config()),
reloc_model, llvm::None, opt_level));
@@ -755,7 +757,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ absl::make_unique<SequentialHloOrdering>(module, module_sequence),
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
@@ -849,7 +851,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment->GetUniqueTopLevelOutputSlice());
- results.emplace_back(MakeUnique<CpuAotCompilationResult>(
+ results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_infos),
result_slice.index(), std::move(hlo_profile_printer_data)));
}
@@ -872,7 +874,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::host::kHostPlatformId,
- []() { return xla::MakeUnique<xla::cpu::CpuCompiler>(); });
+ []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 991b14f17d..e6130c7d76 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -697,8 +697,9 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
if (add_extra_use_for_dot) {
+ auto* token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
- HloInstruction::CreateOutfeed(dot_shape, dot, "no_config"));
+ HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config"));
}
module->AddEntryComputation(builder.Build());
@@ -791,11 +792,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -807,11 +808,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -823,11 +824,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -839,11 +840,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -855,11 +856,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -871,11 +872,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -887,11 +888,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index 2ac950e6d9..bc4cfc0999 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -46,7 +46,7 @@ std::unique_ptr<Array2D<float>> MaybeTransposeArray2D(const Array2D<T>& array,
if (transpose) {
std::swap(output_width, output_height);
}
- auto output = MakeUnique<Array2D<float>>(output_height, output_width);
+ auto output = absl::make_unique<Array2D<float>>(output_height, output_width);
for (int y = 0; y < array.height(); y++) {
for (int x = 0; x < array.width(); x++) {
if (transpose) {
@@ -93,7 +93,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it. Swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
@@ -204,7 +204,7 @@ std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it, swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_MKLSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 59bc7e0e16..b07cd675ff 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
@@ -256,7 +257,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
VLOG(2)
<< "Enqueueing outfeed buffer (for the device to populate) of length "
<< size_32 << "B";
- buffers.emplace_back(MakeUnique<CpuOutfeedBuffer>(b.first, size_32));
+ buffers.emplace_back(absl::make_unique<CpuOutfeedBuffer>(b.first, size_32));
}
std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers;
@@ -283,7 +284,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() {
- return xla::MakeUnique<xla::CpuTransferManager>();
+ return absl::make_unique<xla::CpuTransferManager>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 4fa5984b04..286d407ca6 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
@@ -109,7 +110,7 @@ ParallelTaskAssignment::ParallelTaskAssignment(
: target_machine_features_(*target_machine_features) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
- auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
+ auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
HloComputation* computation = module->entry_computation();
Status status = computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 36c9f74385..ee272b5f4f 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -110,9 +110,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
const string hlo_string = R"(
HloModule TestTaskParallel_infeed_outfeed
ENTRY InfeedOutfeed {
- infeed0 = (u32[12345678,2]{1,0}, token[]) infeed()
+ token = token[] after-all()
+ infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token)
infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0
- ROOT outfeed0 = token[] outfeed(infeed0.data)
+ ROOT outfeed0 = token[] outfeed(infeed0.data, token)
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index be772cfb7e..b026aef3fe 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include <list>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 181cec3cdd..4635fa5d74 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -51,6 +51,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +95,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index d98856fdbf..b68ac67574 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
index 90b99c828e..3b87683fff 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -38,7 +38,8 @@ while_body {
while_cond {
arg_cond = f32[2,3,2] parameter(0)
- infeed = (pred[], token[]) infeed()
+ token = token[] after-all()
+ infeed = (pred[], token[]) infeed(token)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
@@ -50,8 +51,9 @@ ENTRY main {
{{2, 1}, {2001, 3002}, {2001, 2002}}})
const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body
- out0 = token[] outfeed(f32[2,3,2] const_a)
- ROOT out1 = token[] outfeed(f32[2,3,2] const_b)
+ token = token[] after-all()
+ out0 = token[] outfeed(f32[2,3,2] const_a, token[] token)
+ ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token)
}
)";
@@ -85,7 +87,8 @@ while_body {
while_cond {
arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
- infeed = (pred[], token[]) infeed()
+ token = token[] after-all()
+ infeed = (pred[], token[]) infeed(token)
ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
@@ -94,8 +97,9 @@ ENTRY main {
const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body
- out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a)
- ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b)
+ token = token[] after-all()
+ out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token)
+ ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token)
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 01daed4bcd..bb105194f1 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) {
// Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it.
auto status_or_buffer_assn = BufferAssigner::Run(
- hlo_module.get(), MakeUnique<DependencyHloOrdering>(hlo_module.get()),
+ hlo_module.get(),
+ absl::make_unique<DependencyHloOrdering>(hlo_module.get()),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return /*alignment=*/1; });
ASSERT_EQ(status_or_buffer_assn.status(), Status::OK());
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index dac416e1c7..780c07f819 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -32,7 +32,8 @@ ENTRY main {
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})
- outfeed = token[] outfeed(f32[2,3,2] const_a)
+ token = token[] after-all()
+ outfeed = token[] outfeed(f32[2,3,2] const_a, token)
ROOT root = () tuple()
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 3274be8d9d..962ea69c09 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -422,8 +423,8 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support,
std::vector<llvm::Value*> TileVariable::Get() const {
std::vector<llvm::Value*> result;
- c_transform(storage_, std::back_inserter(result),
- [&](VectorVariable vect_var) { return vect_var.Get(); });
+ absl::c_transform(storage_, std::back_inserter(result),
+ [&](VectorVariable vect_var) { return vect_var.Get(); });
return result;
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2e9d6be2de..4b19aa5df9 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
@@ -1672,22 +1673,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- int64 output_window_dim =
- dim_numbers.output_window_dims(operand_index_dim++);
+ int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
operand_to_output_dim[i] = output_window_dim;
operand_index.push_back(index[output_window_dim]);
}
}
- // This is the index of the index vector in the gather_indices tensor.
+ // This is the index of the index vector in the start_indices tensor.
IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
gather_index_index.push_back(index[i]);
}
}
@@ -1700,7 +1700,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
b_->CreateSExtOrTrunc(index_component, index_type);
- int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
// This means we set the iteration index to 0, so for the purpose of the
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index fd75847d0c..1c9f396b68 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status.h"
@@ -76,8 +77,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
- ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
- &hlo_profile_index_map())
+ ? absl::make_unique<HloExecutionProfile>(&hlo_profile_printer_data(),
+ &hlo_profile_index_map())
: nullptr;
StatusOr<ScopedShapedBuffer> return_value =
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index 228c3fac95..70a78c8a2b 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend,
tensorflow::mutex_lock lock(execution_mutex_);
int64 handle = next_handle_++;
auto inserted = handle_to_execution_.emplace(
- handle,
- MakeUnique<AsyncExecution>(backend, std::move(streams), profile, result));
+ handle, absl::make_unique<AsyncExecution>(backend, std::move(streams),
+ profile, result));
CHECK(inserted.second);
ExecutionHandle execution_handle;
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index e3a42d0d06..d889fd8e88 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -27,85 +28,85 @@ namespace xla {
using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
- HloInstruction* gather_indices, int64 index_vector_dim) {
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices, int64 index_vector_dim) {
+ const Shape& start_indices_shape = start_indices->shape();
- if (gather_indices_shape.dimensions_size() == index_vector_dim) {
- return gather_indices;
+ if (start_indices_shape.dimensions_size() == index_vector_dim) {
+ return start_indices;
}
- if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
- return gather_indices;
+ if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
+ return start_indices;
}
std::vector<int64> permutation;
- permutation.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ permutation.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
permutation.push_back(i);
}
}
permutation.push_back(index_vector_dim);
- return MakeTransposeHlo(gather_indices, permutation);
+ return MakeTransposeHlo(start_indices, permutation);
}
-// Canonicalizes the gather_indices tensors so that we only have deal with some
+// Canonicalizes the start_indices tensors so that we only have deal with some
// specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
- HloInstruction* gather_indices, int64 index_vector_dim) {
+ HloInstruction* start_indices, int64 index_vector_dim) {
// Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
- HloInstruction * transposed_gather_indices,
- TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+ HloInstruction * transposed_start_indices,
+ TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
bool indices_are_scalar =
- index_vector_dim == gather_indices->shape().dimensions_size();
+ index_vector_dim == start_indices->shape().dimensions_size();
- // The number of dimensions in gather_indices that are index dimensions.
- const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
+ // The number of dimensions in start_indices that are index dimensions.
+ const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;
- // If there is only one index (i.e. gather_indices has rank 1 and this gather
+ // If there is only one index (i.e. start_indices has rank 1 and this gather
// is really just a dynamic slice) add a leading degenerate dimension for
// uniformity. Otherwise create a "collapsed" leading dimension that subsumes
// all of the non-index-vector dimensions.
- const Shape& shape = transposed_gather_indices->shape();
- if (shape.dimensions_size() == index_dims_in_gather_indices) {
- return PrependDegenerateDims(transposed_gather_indices, 1);
+ const Shape& shape = transposed_start_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_start_indices) {
+ return PrependDegenerateDims(transposed_start_indices, 1);
} else {
- // Collapse all but the dimensions (0 or 1) in gather_indices containing the
+ // Collapse all but the dimensions (0 or 1) in start_indices containing the
// index vectors.
return CollapseFirstNDims(
- transposed_gather_indices,
- shape.dimensions_size() - index_dims_in_gather_indices);
+ transposed_start_indices,
+ shape.dimensions_size() - index_dims_in_start_indices);
}
}
// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
-static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
- const Shape& gather_indices_shape, HloInstruction* accumulator,
+static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
+ const Shape& start_indices_shape, HloInstruction* accumulator,
int64 index_vector_dim) {
- std::vector<int64> output_gather_dim_bounds;
- output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ std::vector<int64> batch_dim_bounds;
+ batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
- output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i));
+ batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
}
}
- if (output_gather_dim_bounds.empty()) {
- // If output_gather_dim_bounds is empty we must be lowering a (effectively)
+ if (batch_dim_bounds.empty()) {
+ // If batch_dim_bounds is empty we must be lowering a (effectively)
// dynamic-slice. In that case, there is a leading degenerate gather
// dimension that we added to make this special case play well with the
// general while loop which we need to remove now.
return ElideDegenerateDims(accumulator, {0});
}
- return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
+ return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}
-// Expand an index vector from the gather_indices tensor into a vector that can
+// Expand an index vector from the start_indices tensor into a vector that can
// be used to dynamic-slice out of the gather operand.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
@@ -121,10 +122,8 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
std::vector<HloInstruction*> expanded_index_components;
for (int i = 0; i < operand_rank; i++) {
- int64 index_vector_dim_index =
- FindIndex(dim_numbers.gather_dims_to_operand_dims(), i);
- if (index_vector_dim_index !=
- dim_numbers.gather_dims_to_operand_dims_size()) {
+ int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
+ if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
@@ -147,10 +146,10 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
CHECK_EQ(incoming_loop_state.size(), 3);
HloInstruction* const operand = incoming_loop_state[0];
- HloInstruction* const gather_indices = incoming_loop_state[1];
+ HloInstruction* const start_indices = incoming_loop_state[1];
HloInstruction* const output_accumulator = incoming_loop_state[2];
- bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+ bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
CHECK_EQ(has_scalar_indices,
dim_numbers.index_vector_dim() ==
gather.operand(1)->shape().dimensions_size());
@@ -163,24 +162,24 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
HloInstruction* index_vector;
if (has_scalar_indices) {
- // In this case gather_indices has rank 1 and induction_var_as_vector (of
+ // In this case start_indices has rank 1 and induction_var_as_vector (of
// shape {1}) is an index into this rank 1 tensor.
TF_ASSIGN_OR_RETURN(
index_vector,
- MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+ MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
} else {
- // In this case gather_indices has rank 2 and induction_var_as_vector (of
+ // In this case start_indices has rank 2 and induction_var_as_vector (of
// shape {1}) is an index into just the first dimension of this rank 2
// tensor.
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_into_gather_indices,
+ HloInstruction * index_into_start_indices,
PadVectorWithZeros(induction_var_as_vector,
/*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
- int64 index_vector_size = gather_indices->shape().dimensions(1);
+ int64 index_vector_size = start_indices->shape().dimensions(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
- MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ MakeDynamicSliceHlo(start_indices, index_into_start_indices,
{1, index_vector_size}));
TF_ASSIGN_OR_RETURN(index_vector,
@@ -194,26 +193,26 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
MakeDynamicSliceHlo(operand, gathered_slice_start,
- gather.gather_window_bounds()));
+ gather.gather_slice_sizes()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_with_dims_elided,
+ HloInstruction* const gathered_slice_with_dims_collapsed,
ElideDegenerateDims(gathered_slice,
- AsInt64Slice(dim_numbers.elided_window_dims())));
+ AsInt64Slice(dim_numbers.collapsed_slice_dims())));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_for_update,
- PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
+ HloInstruction* const gathered_slice_for_update,
+ PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_vector_into_accumulator,
+ HloInstruction* const index_vector_into_accumulator,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
/*zeros_to_append=*/
- gathered_slice_with_dims_elided->shape().dimensions_size()));
+ gathered_slice_with_dims_collapsed->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * updated_accumulator,
+ HloInstruction* const updated_accumulator,
MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
index_vector_into_accumulator));
@@ -221,19 +220,19 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
// WhileUtil::MakeCountedLoop functions takes care of the induction variable
// and the while loop exit condition.
return StatusOr<std::vector<HloInstruction*>>{
- {operand, gather_indices, updated_accumulator}};
+ {operand, start_indices, updated_accumulator}};
}
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+ ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
- accumulator_state_shape_dims.reserve(1 + window_bounds.size());
+ accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- for (int64 i = 0; i < window_bounds.size(); i++) {
- if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
- accumulator_state_shape_dims.push_back(window_bounds[i]);
+ for (int64 i = 0; i < slice_sizes.size(); i++) {
+ if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
return BroadcastZeros(computation, element_type,
@@ -241,23 +240,23 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
}
// `accumulator` is almost the tensor the gather operation would have produced,
-// except that it has the dimensions in the wrong order -- the gather dimensions
-// are the major dimensions and the window dimensions are the minor dimensions.
+// except that it has the dimensions in the wrong order -- the batch dimensions
+// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
-static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
- HloInstruction* accumulator, ArraySlice<int64> output_window_dims,
+static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
+ HloInstruction* accumulator, ArraySlice<int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
- int64 gather_idx_counter = 0;
- int64 window_idx_counter = output_rank - output_window_dims.size();
+ int64 batch_idx_counter = 0;
+ int64 offset_idx_counter = output_rank - offset_dims.size();
for (int64 i = 0; i < output_rank; i++) {
- bool is_window_dim = c_binary_search(output_window_dims, i);
- if (is_window_dim) {
- permutation.push_back(window_idx_counter++);
+ bool is_offset_dim = absl::c_binary_search(offset_dims, i);
+ if (is_offset_dim) {
+ permutation.push_back(offset_idx_counter++);
} else {
- permutation.push_back(gather_idx_counter++);
+ permutation.push_back(batch_idx_counter++);
}
}
@@ -268,11 +267,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
//
// We follow the following steps in sequence:
//
-// 1. We canonicalize the gather_indices tensor such that it has rank
+// 1. We canonicalize the start_indices tensor such that it has rank
// 2 (i.e. is a matrix) where each row is an index vector into the
// operand.
// 2. We iterate over the set of indices in the canonicalized
-// gather_indices tensor using a while loop, accumulating slices
+// start_indices tensor using a while loop, accumulating slices
// of the operand tensor into an accumulator using
// DynamicUpdateSlice.
// 3. The accumulator result from the while loop from (2) is then
@@ -287,11 +286,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
// operand = s32[3,3] parameter(0)
// indices = s32[2,2] parameter(1)
// ROOT gather = s32[2,3,2] gather(operand, indices),
-// output_window_dims={1},
-// elided_window_dims={1},
-// gather_dims_to_operand_dims={1},
+// offset_dims={1},
+// collapsed_slice_dims={1},
+// start_index_map={1},
// index_vector_dim=2,
-// window_bounds={3, 1}
+// slice_sizes={3, 1}
// }
//
// We'd first reshape indices to s32[4,1], where each row is an index
@@ -305,8 +304,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
- HloInstruction* gather_indices = gather_instr->mutable_operand(1);
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices = gather_instr->mutable_operand(1);
+ const Shape& start_indices_shape = start_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
@@ -314,9 +313,9 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->gather_dimension_numbers();
int64 gather_loop_trip_count = 1;
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
- gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ gather_loop_trip_count *= start_indices_shape.dimensions(i);
}
}
@@ -327,24 +326,24 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->ToString().c_str());
}
- TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
- CanonicalizeGatherIndices(
- gather_indices, dim_numbers.index_vector_dim()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_start_indices,
+ CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));
CHECK_EQ(gather_loop_trip_count,
- canonical_gather_indices->shape().dimensions(0));
+ canonical_start_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
CreateGatherLoopAccumulatorInitValue(
computation, output_shape.element_type(),
- gather_instr->gather_window_bounds(), gather_loop_trip_count,
+ gather_instr->gather_slice_sizes(), gather_loop_trip_count,
gather_instr->gather_dimension_numbers()));
StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
WhileUtil::MakeCountedLoop(
computation, gather_loop_trip_count,
- {operand, canonical_gather_indices, accumulator_init},
+ {operand, canonical_start_indices, accumulator_init},
[&](HloInstruction* indvar,
const std::vector<HloInstruction*>& loop_state) {
return GatherLoopBody(*gather_instr, indvar, loop_state);
@@ -356,13 +355,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* accumulator_result = gather_loop_result.back();
TF_ASSIGN_OR_RETURN(
- HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
- dim_numbers.index_vector_dim()));
+ HloInstruction* const accumulator_with_batch_dims_decanonicalized,
+ AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
+ dim_numbers.index_vector_dim()));
- return PermuteGatherAndWindowDims(
- accumulator_with_output_gather_dims_decanonicalized,
- AsInt64Slice(dim_numbers.output_window_dims()), output_rank);
+ return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
+ AsInt64Slice(dim_numbers.offset_dims()),
+ output_rank);
}
StatusOr<bool> GatherExpander::Run(HloModule* module) {
@@ -375,8 +374,8 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
std::vector<HloInstruction*> gather_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(gather_instrs),
- is_nontrivial_gather);
+ absl::c_copy_if(computation->instructions(),
+ std::back_inserter(gather_instrs), is_nontrivial_gather);
}
for (HloInstruction* inst : gather_instrs) {
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 020ffcd106..141dd4d6f1 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -28,11 +28,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2147483647,5] parameter(1)
ROOT gather = s32[2147483647,3,5] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -55,11 +55,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 19575c7905..17eefc430d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -56,6 +56,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -91,6 +92,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -107,6 +109,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -180,6 +183,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
"@llvm//:support",
],
@@ -243,6 +248,7 @@ cc_library(
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -257,6 +263,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -337,6 +344,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
],
)
@@ -362,6 +370,7 @@ cc_library(
hdrs = ["cudnn_convolution_algorithm_picker.h"],
deps = [
":backend_configs",
+ ":buffer_comparator",
":cudnn_convolution_runner",
":gpu_executable",
":ir_emission_utils",
@@ -465,6 +474,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -512,6 +522,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -543,6 +554,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:shape_inference",
+ "@com_google_absl//absl/memory",
],
)
@@ -599,6 +611,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
alwayslink = True, # Contains per-platform transfer manager registration
@@ -638,6 +651,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
+ "//tensorflow/compiler/xla/service:convolution_feature_group_converter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@@ -668,6 +682,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
alwayslink = True, # Contains compiler registration
@@ -700,8 +715,8 @@ cc_library(
":xfeed_queue",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -716,6 +731,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -765,12 +781,12 @@ cc_library(
":stream_assignment",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/compiler/xla/service:hlo_scheduling",
+ "@com_google_absl//absl/memory",
],
)
@@ -787,6 +803,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index 537295292b..e208ad61e3 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
const BufferAssignment* buffer_assignment, int device_ordinal,
DeviceMemoryAllocator* memory_allocator) {
const int64 num_buffers = buffer_assignment->Allocations().size();
- auto buffer_allocations = WrapUnique(new BufferAllocations(
+ auto buffer_allocations = absl::WrapUnique(new BufferAllocations(
num_buffers, device_ordinal, memory_allocator, buffer_assignment));
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 5780e0af40..8b0426aa27 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 7d93bdfc8b..caeb89d78e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/core/lib/gtl/optional.h"
@@ -177,6 +178,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
+ CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
+ CHECK_EQ(input_shape.element_type(), output_shape.element_type());
+ // TODO(timshen): for now only check fp16. It can be expanded to other types,
+ // with some work on the HLO routines.
+ const bool cross_check_enabled = input_shape.element_type() == xla::F16;
+
// Don't run this function concurrently on the same GPU.
//
// This is a bit of a hack and doesn't protect us against arbitrary concurrent
@@ -216,20 +223,64 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(output_shape)));
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the inputs
- // might affect performance if e.g. the inputs contain denormals, and this is
- // easy enough.
- TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size())
- .BlockHostUntilDone());
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for the cross checking, because zero-inputs
+ // may not always reveal the bugs.
+ const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
+ size_t left_over_bytes = buffer.size() % 4;
+ CHECK_EQ(0, left_over_bytes % 2);
+
+ constexpr float kBroadcastedConstant = 0.1f;
+ Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant),
+ Eigen::half(kBroadcastedConstant)};
+ uint32 bits;
+ static_assert(sizeof(bits) == sizeof(halfs), "");
+ memcpy(&bits, halfs, sizeof(bits));
+
+ size_t aligned_size = buffer.size() / 4 * 4;
+ stream.ThenMemset32(&buffer, bits, aligned_size);
+
+ DeviceMemoryBase left_over(
+ static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
+ stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
+ };
+ initialize_f16(input_buf);
+ initialize_f16(filter_buf);
+ initialize_f16(output_buf);
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers before
+ // autotuning. It's conceivable that using uninitialized memory as the
+ // inputs might affect performance if e.g. the inputs contain denormals, and
+ // this is easy enough.
+ stream.ThenMemZero(&input_buf, input_buf.size())
+ .ThenMemZero(&filter_buf, filter_buf.size())
+ .ThenMemZero(&output_buf, output_buf.size());
+ }
+ TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
+
+ DeviceMemoryBase* result_buf = [&] {
+ switch (kind) {
+ case CudnnConvKind::kBackwardFilter:
+ return &filter_buf;
+ case CudnnConvKind::kBackwardInput:
+ return &input_buf;
+ case CudnnConvKind::kForward:
+ return &output_buf;
+ }
+ }();
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
input_shape, output_shape, dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
+ optional<F16BufferComparator> comparator;
+ // Use the first algorithm that's supported as reference. There isn't a
+ // particular reason to use it, as any algorithm sufficies. It doesn't make
+ // this algorithm considered correct, though.
+ optional<AlgorithmDesc> first_algorithm;
for (const AlgorithmDesc& alg :
GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
@@ -245,6 +296,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.ok();
if (launch_ok && profile_result.is_valid()) {
+ const bool crash_on_checking_failure =
+ instr->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_crash_on_verification_failures();
+ if (comparator.has_value()) {
+ StatusOr<bool> result = comparator->CompareEqual(
+ se::DeviceMemory<Eigen::half>(*result_buf));
+ if (!result.ok()) {
+ LOG(ERROR) << "Unable to compare "
+ << AlgorithmToString(*first_algorithm) << " against "
+ << AlgorithmToString(alg) << " for " << instr->ToString()
+ << ": " << result.status();
+ CHECK(!crash_on_checking_failure);
+ } else if (!result.ValueOrDie()) {
+ LOG(ERROR) << "Results mismatch between different convolution "
+ "algorithms. This is likely a bug in convolution, or "
+ "an excessive loss of precision in convolution. "
+ << instr->ToString() << " for "
+ << AlgorithmToString(*first_algorithm) << " vs "
+ << AlgorithmToString(alg);
+ CHECK(!crash_on_checking_failure);
+ }
+ } else if (cross_check_enabled) {
+ auto comp = F16BufferComparator::Create(
+ se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator,
+ &stream);
+ if (comp.ok()) {
+ comparator.emplace(comp.ConsumeValueOrDie());
+ first_algorithm.emplace(alg);
+ } else {
+ LOG(ERROR) << "Fail to initialize buffer comparator: "
+ << comp.status() << ", instruction: " << instr->ToString();
+ CHECK(!crash_on_checking_failure);
+ }
+ }
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
VLOG(3) << "Run of algorithm " << AlgorithmToString(alg)
<< " succeeded, taking " << profile_result.elapsed_time_in_ms()
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 0645fbb3ad..7b0d9e53d6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -96,15 +96,9 @@ Status RunCudnnConvolution(
// tensorflow/python/ops/nn_ops.py).
const int effective_num_dimensions = std::max(2, num_dimensions);
- if (std::is_same<T, float>::value) {
- CHECK_EQ(F32, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else if (std::is_same<T, Eigen::half>::value) {
- CHECK_EQ(F16, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else {
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
- }
+ CHECK_EQ(primitive_util::NativeToPrimitiveType<T>(),
+ output_shape.element_type())
+ << ShapeUtil::HumanString(output_shape);
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
@@ -246,21 +240,31 @@ Status RunCudnnConvolution(
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
- CHECK(output_primitive_type == F32 || output_primitive_type == F16)
- << ShapeUtil::HumanString(output_shape);
- if (output_primitive_type == F32) {
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- algorithm, stream, profile_result);
+ switch (output_primitive_type) {
+ case F16:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F32:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F64:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ default:
+ LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 2fd2206324..88f0b4d71c 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 3cd30b754c..9b86e5315b 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -64,10 +65,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) {
// Slice for a more accurate estimate of bytes read.
double bytes = 0.0;
for (auto& instruction : instructions) {
- if (c_all_of(instruction->users(), [](const HloInstruction* instruction) {
- return instruction->opcode() == HloOpcode::kSlice ||
- instruction->opcode() == HloOpcode::kDynamicSlice;
- })) {
+ if (absl::c_all_of(
+ instruction->users(), [](const HloInstruction* instruction) {
+ return instruction->opcode() == HloOpcode::kSlice ||
+ instruction->opcode() == HloOpcode::kDynamicSlice;
+ })) {
// All users are slice: accumulate bytes of all user slice instructions.
for (auto& user : instruction->users()) {
bytes += ShapeUtil::ByteSizeOf(user->shape());
@@ -223,7 +225,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
user->fusion_kind() == HloInstruction::FusionKind::kInput);
@@ -241,11 +243,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// If 'fusion' has just one user, then an earlier fusion pass chose not to
// fuse this producer/comsumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
- if (c_any_of(fusion->fused_instructions(),
- [](const HloInstruction* instruction) {
- return instruction->opcode() != HloOpcode::kParameter &&
- GpuInstructionFusion::IsExpensive(*instruction);
- })) {
+ if (absl::c_any_of(fusion->fused_instructions(),
+ [](const HloInstruction* instruction) {
+ return instruction->opcode() != HloOpcode::kParameter &&
+ GpuInstructionFusion::IsExpensive(*instruction);
+ })) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 7060837904..a1fbd8022d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks(
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
if (thunk_schedule_->Depended(thunk)) {
- auto finish_event = MakeUnique<se::Event>(main_stream->parent());
+ auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index a2f53f8446..44303724bb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/IR/DataLayout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
if (ShapeUtil::IsTuple(shape)) {
return;
}
- *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ *buffer = absl::make_unique<gpu::OutfeedBuffer>(
+ GetByteSizeRequirement(shape));
(*buffer)->set_destination(
- MakeUnique<MutableBorrowingLiteral>(literal, index));
+ absl::make_unique<MutableBorrowingLiteral>(literal, index));
});
// Give the tree of buffers to the outfeed mananger. The device will fill it
@@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() {
- return xla::MakeUnique<xla::gpu::GpuTransferManager>(
+ return absl::make_unique<xla::gpu::GpuTransferManager>(
/*id=*/stream_executor::cuda::kCudaPlatformId,
/*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout)
.getPointerSize(0 /* default address space */));
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index 1722676930..b9c21e8edb 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -33,7 +34,7 @@ namespace gpu {
namespace {
void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
se::Stream* stream) {
- timers->push(MakeUnique<se::Timer>(stream->parent()));
+ timers->push(absl::make_unique<se::Timer>(stream->parent()));
stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
}
@@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler(
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
}
- return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
+ return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index 19de37b0fb..76055ff009 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering(
: PredecessorHloOrdering(module) {
// The entry computation has a total order when there's only one stream.
if (stream_assignment.StreamCount() == 1) {
- entry_sequence_ =
- MakeUnique<std::vector<const HloInstruction*>>(thunk_launch_order);
+ entry_sequence_ = absl::make_unique<std::vector<const HloInstruction*>>(
+ thunk_launch_order);
}
// The ordering of instructions for the entry computation is determined by the
@@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering(
// same-stream predecessors of each instruction.
// Compute the set of all instructions we will want to set reachability on.
- auto predecessor_map = MakeUnique<HloReachabilityMap>(
+ auto predecessor_map = absl::make_unique<HloReachabilityMap>(
module->entry_computation()->MakeInstructionPostOrder());
// The most recently visited instruction per stream.
@@ -208,7 +208,7 @@ StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
}
- schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>(
+ schedule->hlo_ordering_ = absl::make_unique<GpuHloOrdering>(
&module, stream_assignment, schedule->thunk_launch_order_);
return std::move(schedule);
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index 45f0a1c645..d4a96cd5b3 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <unordered_set>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -47,7 +48,7 @@ class HloScheduleTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
HloVec RemoveHlo(const HloVec& input,
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index c5f0cdf6cd..a4364b0deb 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
namespace xla {
namespace gpu {
@@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(host_to_device_stream_mu_);
if (host_to_device_executor_ == nullptr) {
host_to_device_executor_ = executor;
- host_to_device_stream_ = MakeUnique<se::Stream>(executor);
+ host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_->Init();
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 6675dbd3f9..7111b53944 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -518,7 +519,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// We don't have to iterate over the batch dimensions in both arrays, simplify
// the loop nest of the rhs.
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
- DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+ DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
rhs_index[i] = lhs_index[i];
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1e81cbde35..dea2a31920 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -29,7 +31,6 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -314,13 +315,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
};
// Check the size of input tensors
- if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
+ if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
return i64_ty;
}
// Check the size of the internal result tensors
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
- if (!c_all_of(
+ if (!absl::c_all_of(
unnested_hlo->fused_instructions_computation()->instructions(),
hlo_shape_in_range)) {
return i64_ty;
@@ -383,7 +384,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
int64 feature_index_value = feature_index->literal().Get<int64>({});
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardInferenceThunk>(
+ absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -413,7 +414,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardTrainingThunk>(
+ absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -443,19 +444,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_grad_offset =
assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
- thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>(
- /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
- /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
- /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
- /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
- /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
- /*epsilon=*/epsilon_value,
- /*feature_index=*/feature_index_value,
- /*output_grad_data=*/output_grad_data,
- /*output_grad_scale=*/output_grad_scale,
- /*output_grad_offset=*/output_grad_offset,
- /*output_tuple=*/GetAllocationSlice(*custom_call),
- /*hlo=*/custom_call));
+ thunk_sequence_->emplace_back(
+ absl::make_unique<CudnnBatchNormBackwardThunk>(
+ /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
+ /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
+ /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
+ /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
+ /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
+ /*epsilon=*/epsilon_value,
+ /*feature_index=*/feature_index_value,
+ /*output_grad_data=*/output_grad_data,
+ /*output_grad_scale=*/output_grad_scale,
+ /*output_grad_offset=*/output_grad_offset,
+ /*output_tuple=*/GetAllocationSlice(*custom_call),
+ /*hlo=*/custom_call));
return Status::OK();
}
@@ -475,7 +477,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
const auto& target = custom_call->custom_call_target();
std::unique_ptr<ConvolutionThunk> thunk;
if (target == kCudnnConvForwardCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kForward,
/*input_buffer=*/lhs_slice,
/*filter_buffer=*/rhs_slice,
@@ -489,7 +491,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
backend_config.algorithm(), backend_config.tensor_ops_enabled(),
custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
/*input_buffer=*/conv_result_slice,
/*filter_buffer=*/rhs_slice,
@@ -503,7 +505,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
backend_config.algorithm(), backend_config.tensor_ops_enabled(),
custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
+ thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
/*input_buffer=*/lhs_slice,
/*filter_buffer=*/conv_result_slice,
@@ -576,7 +578,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunks.push_back(
BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), fusion));
+ absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
@@ -1718,7 +1720,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
thunks.push_back(
BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), reduce));
+ absl::make_unique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
reduce, input->shape(), {[&](const IrArray::Index& index) {
@@ -1738,7 +1740,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
bool all_tuple_elements_have_buffer =
- c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
+ absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
return ir_emitter_context_->buffer_assignment()
.GetUniqueTopLevelSlice(tuple_element)
.ok();
@@ -1760,7 +1762,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* tuple_element : tuple->operands()) {
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
}
- thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
+ thunk_sequence_->emplace_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
@@ -1792,8 +1794,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
thunks.push_back(std::move(initializer_thunk));
thunks.push_back(BuildKernelThunk(select_and_scatter,
/*implements_whole_instruction=*/false));
- thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
+ thunk_sequence_->emplace_back(absl::make_unique<SequentialThunk>(
+ std::move(thunks), select_and_scatter));
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
@@ -2018,7 +2020,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
thunks.push_back(std::move(rng_thunk));
thunks.push_back(std::move(increment_seed_thunk));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), rng));
+ absl::make_unique<SequentialThunk>(std::move(thunks), rng));
return Status::OK();
}
@@ -2043,7 +2045,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
auto values_destination = GetAllocationSlice(*sort, values_shape_index);
if (keys_destination != GetAllocationSlice(*keys)) {
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*keys),
/*destination_buffer=*/keys_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr));
@@ -2051,7 +2053,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
if (values != nullptr && values_destination != GetAllocationSlice(*values)) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*values),
/*destination_buffer=*/values_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr));
@@ -2103,7 +2105,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
}
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), sort));
+ absl::make_unique<SequentialThunk>(std::move(thunks), sort));
return Status::OK();
}
@@ -2130,7 +2132,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
if (crs->operand_count() == 1) {
CHECK(ShapeUtil::IsArray(crs->operand(0)->shape()))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunk_sequence_->push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
@@ -2145,17 +2147,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(crs, {i})
.ValueOrDie());
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
- thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), nullptr));
+ thunks.push_back(absl::make_unique<TupleThunk>(
+ tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
- MakeUnique<SequentialThunk>(std::move(thunks), crs));
+ absl::make_unique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
}
@@ -2322,10 +2324,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// We'll pass a pointer to each of the elements of `buffers` to our kernel, in
// this order.
std::vector<const BufferAllocation*> non_constant_buffers;
- c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
- [](const BufferAllocation* allocation) {
- return !allocation->is_constant();
- });
+ absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
+ [](const BufferAllocation* allocation) {
+ return !allocation->is_constant();
+ });
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
[](const BufferAllocation* a, const BufferAllocation* b) {
@@ -2389,7 +2391,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
- return MakeUnique<KernelThunk>(
+ return absl::make_unique<KernelThunk>(
non_constant_buffers, llvm_ir::AsString(kernel->getName()),
implements_whole_instruction ? inst : nullptr, unroll_factor);
}
@@ -2398,7 +2400,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
CHECK_EQ(HloOpcode::kConstant, operand->opcode());
- return MakeUnique<HostToDeviceCopyThunk>(
+ return absl::make_unique<HostToDeviceCopyThunk>(
/*source_address=*/operand->literal().untyped_data(),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2410,7 +2412,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<DeviceToDeviceCopyThunk>(
+ return absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*operand),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2430,7 +2432,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
});
- return MakeUnique<InfeedThunk>(slices, inst);
+ return absl::make_unique<InfeedThunk>(slices, inst);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
@@ -2447,7 +2449,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
*slice = status_or_slice.ConsumeValueOrDie();
}
});
- return MakeUnique<OutfeedThunk>(std::move(slices), inst);
+ return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
}
namespace {
@@ -2470,7 +2472,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (inst->opcode() == HloOpcode::kDot) {
const HloInstruction* lhs = inst->operand(0);
const HloInstruction* rhs = inst->operand(1);
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2512,7 +2514,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* rhs =
inst->operand(rhs_parameter->parameter_number());
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2529,11 +2531,12 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
- /*input_buffer=*/GetAllocationSlice(*operand),
- /*output_buffer=*/GetAllocationSlice(*inst),
- /*input_shape=*/operand->shape(),
- /*output_shape=*/inst->shape(), inst);
+ return absl::make_unique<FftThunk>(
+ inst->fft_type(), inst->fft_length(),
+ /*input_buffer=*/GetAllocationSlice(*operand),
+ /*output_buffer=*/GetAllocationSlice(*inst),
+ /*input_shape=*/operand->shape(),
+ /*output_shape=*/inst->shape(), inst);
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
@@ -2582,9 +2585,9 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// MemzeroThunk.
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
- if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {
- MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
+ if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
+ return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
+ nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2601,7 +2604,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
@@ -2612,7 +2615,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
literal_bytes.size() - 4) == 0) {
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
@@ -2764,7 +2767,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<WhileThunk>(
+ return absl::make_unique<WhileThunk>(
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
@@ -2782,8 +2785,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<ForThunk>(loop_limit,
- ir_emitter_body.ConsumeThunkSequence(), hlo);
+ return absl::make_unique<ForThunk>(
+ loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@@ -2803,7 +2806,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
ir_emitter_context_);
TF_CHECK_OK(false_computation->Accept(&ir_emitter_false));
- return MakeUnique<ConditionalThunk>(
+ return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo->operand(1)),
GetAllocationSlice(*hlo->operand(2)),
@@ -3105,7 +3108,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
}
const int64 num_tiles =
- c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
llvm::Type* index_ty =
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index e76823ad10..6305396635 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
@@ -95,7 +95,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << "Launching " << kernel->name();
// Launch the kernel with potentially multiple blocks and threads.
static constexpr int kKernelArgsLimit = 1024;
- auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>();
+ auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
for (const BufferAllocation* arg : args_) {
const auto& buf = buffer_allocations.GetDeviceAddress(arg->index());
kernel_args->add_device_memory_argument(buf);
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index eb93efc560..6bd9c58f83 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -34,6 +34,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
"@llvm//:amdgpu_code_gen",
"@llvm//:analysis",
"@llvm//:bit_reader",
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index ff4ae1f9ef..cce6e48141 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -205,7 +205,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
default:
codegen_opt_level = CodeGenOpt::None;
}
- return WrapUnique(target->createTargetMachine(
+ return absl::WrapUnique(target->createTargetMachine(
triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options,
Optional<Reloc::Model>(RelocModel), Optional<CodeModel::Model>(CMModel),
codegen_opt_level));
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c62bae0628..34a479b289 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -131,7 +132,7 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
max_rank_layout = &param->shape().layout();
}
}
- return c_all_of(params, [&](HloInstruction* param) {
+ return absl::c_all_of(params, [&](HloInstruction* param) {
return (ShapeUtil::Rank(param->shape()) < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
@@ -248,7 +249,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
}
// Do not fuse a producer if the other operands of the fusion are
// reachable from the producer, this would create a cycle.
- if (c_any_of(consumer_operands, [&](HloInstruction* operand) {
+ if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
@@ -268,7 +269,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
for (auto& fusion_pair : potential_fusion_list) {
HloInstruction* producer = fusion_pair.first;
HloInstruction* consumer = fusion_pair.second;
- if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) {
+ if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index d937123357..5868c1a42e 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -21,19 +21,20 @@ limitations under the License.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -203,6 +204,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantChecker<HloVerifier>();
+ // TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
+ pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
@@ -687,7 +690,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
const std::vector<uint8> cubin =
CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor);
- auto thunk_schedule = MakeUnique<ThunkSchedule>(
+ auto thunk_schedule = absl::make_unique<ThunkSchedule>(
ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
hlo_schedule->ThunkLaunchOrder());
VLOG(2) << "Printing the thunk schedule...";
@@ -701,7 +704,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
- profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
+ profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
}
@@ -810,7 +813,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::cuda::kCudaPlatformId,
- []() { return xla::MakeUnique<xla::gpu::NVPTXCompiler>(); });
+ []() { return absl::make_unique<xla::gpu::NVPTXCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
index 4aaf0c9e14..2fa170964e 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b22040eee1..98cc21ccac 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
+ HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index d3fd0544fb..c927c5ee16 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <ostream>
#include <string>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
index 0806dd5161..5b6cf2c04d 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@@ -119,7 +119,7 @@ int ComputeStreamToAssign(
} // namespace
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
- auto stream_assignment = MakeUnique<StreamAssignment>();
+ auto stream_assignment = absl::make_unique<StreamAssignment>();
const HloComputation& computation = *module.entry_computation();
std::unique_ptr<HloReachabilityMap> reachability =
computation.ComputeReachability();
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 6f4bb0580e..3f75d8b559 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
// Pre-canned shapes.
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 4fad3f46cf..db4a33dc56 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -35,13 +35,13 @@ cc_library(
"requires-gpu-sm35",
],
deps = [
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -60,6 +60,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +95,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -150,6 +152,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -168,6 +171,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
index 4b8415fe91..0e84ec7e62 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/platform/logging.h"
@@ -32,7 +32,7 @@ std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) {
debug_options.add_xla_disable_hlo_passes("constant_folding");
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>(TestName(), config);
+ return absl::make_unique<HloModule>(TestName(), config);
}
void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module,
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index ce69e058e6..4550f36fdf 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
index e5958165ef..a06576df7b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
index 6c9ae7bada..6a9ecd9dae 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
index c42e5704a4..15198865bd 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 8579b1545f..989b542ff4 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
@@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
auto size = tuple_element_buffers_.size();
- auto tuple_element_buffer_addresses = MakeUnique<void*[]>(size);
+ auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);
for (int i = 0; i != size; ++i) {
tuple_element_buffer_addresses[i] =
buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque();
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index d81d87e7dc..828fc2884b 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -34,9 +34,9 @@ WhileThunk::WhileThunk(
// and body_thunk_sequence_ constructors because these SequentialThunks
// are logically "part of" this WhileThunk, and shouldn't be profiled
// separately from it.
- condition_thunk_sequence_(MakeUnique<SequentialThunk>(
+ condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*condition_thunk_sequence), nullptr)),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index aa89567ee8..31431f115f 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -84,7 +84,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation,
// the module.
std::unique_ptr<HloModule> MakeBigGraph() {
HloModuleConfig config;
- auto module = MakeUnique<HloModule>("BigGraph", config);
+ auto module = absl::make_unique<HloModule>("BigGraph", config);
auto builder = HloComputation::Builder("TestBigGraphvizGraph");
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 4005fc0d11..93a922b904 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -45,7 +46,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
// bound, by minimizing the liveness of sub-computations.
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
module_sequence, *points_to_analysis, size_function));
return result.heap_size;
}
@@ -60,9 +61,10 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function,
- HeapSimulator::Options(), memory_by_computation));
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
+ computation, sequence, points_to_analysis,
+ size_function, HeapSimulator::Options(),
+ memory_by_computation));
return result.heap_size;
}
@@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator(
const SequentialHloOrdering::HloModuleSequence* module_sequence,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
- : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
+ : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index b41dc66fe9..5f85f14565 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -137,7 +138,7 @@ class HeapSimulatorTracker {
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@@ -146,8 +147,8 @@ class HeapSimulatorTracker {
// the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
// buffer id, for determinism in the tests.
auto zero_size = [](const BufferValue& buffer) { return 0; };
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
result_ = HeapSimulator::Run(
std::move(algorithm), *module_->entry_computation(),
instruction_sequence, *points_to_analysis_, zero_size)
@@ -156,7 +157,7 @@ class HeapSimulatorTracker {
explicit HeapSimulatorTracker(const string& name) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
}
// Similar to the single entry computation constructor above, but runs the
@@ -182,8 +183,8 @@ class HeapSimulatorTracker {
auto size_fn = [&reverse_position](const BufferValue& buffer) {
return reverse_position[buffer.instruction()];
};
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
result_ = HeapSimulator::Run(std::move(algorithm), *module_,
module_sequence, *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
@@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
+ buffers_.emplace_back(
+ absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
@@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
TEST_F(DecreasingSizeRunsHeapTest, Empty) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Finish();
EXPECT_EQ(call_sequence, CallSequence({
{kFinish, nullptr},
@@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) {
TEST_F(DecreasingSizeRunsHeapTest, Simple) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Alloc(buffer_c_, 30);
@@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) {
TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Free(buffer_b_, 20);
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index be9098f555..fa218657fe 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
+// Next ID: 51
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -74,6 +75,11 @@ message HloInstructionProto {
// Describes the dimension numbers used for a convolution.
xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;
+ // The number of feature groups. Used for a convolution. Must be a divisor of
+ // the input feature dimension and output feature dimension. If not specified,
+ // it will use a default value of 1.
+ int64 feature_group_count = 50;
+
// Describes the [begin, end) index range and stride for slices.
message SliceDimensions {
int64 start = 1;
@@ -133,7 +139,7 @@ message HloInstructionProto {
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
- repeated int64 gather_window_bounds = 34;
+ repeated int64 gather_slice_sizes = 34;
// Compute Host.
string channel_name = 41;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index e8a4b034b4..0ca489846e 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -457,7 +457,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
- auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
+ auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false,
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 441288da1a..70b18ff356 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -23,9 +23,10 @@ limitations under the License.
#include <set>
#include <sstream>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -56,8 +57,8 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root =
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
- return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, fusion_instruction_));
+ return absl::WrapUnique(new HloComputation(
+ name_, parameter_count, &instructions_, root, fusion_instruction_));
}
HloComputation::HloComputation(
@@ -320,6 +321,7 @@ void ComputeComputationPostOrder(
enum State { kVisiting, kVisited };
void ComputeInstructionPostOrder(
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
std::vector<HloInstruction*> dfs_stack;
@@ -354,12 +356,67 @@ void ComputeInstructionPostOrder(
for (HloInstruction* op : current->control_predecessors()) {
dfs_stack.emplace_back(op);
}
+
+ // Add inputs for send->recv_done dependencies and cross-replica-sum
+ // dependencies.
+ switch (current->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies =
+ channel_dependency_map[current->channel_id()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = current->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
}
}
} // namespace
+std::map<int64, std::vector<HloInstruction*>>
+HloComputation::ComputeChannelDependencies() const {
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
+ for (const auto& instruction : instructions_) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ channel_dependency_map[instruction->channel_id()].push_back(
+ instruction.get());
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = instruction->all_reduce_id();
+ if (all_reduce_id) {
+ auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(instruction->operands(),
+ std::back_inserter(dependencies));
+ absl::c_copy(instruction->control_predecessors(),
+ std::back_inserter(dependencies));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ return channel_dependency_map;
+}
+
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
@@ -371,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+ ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+ instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -493,9 +551,9 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
+ &instructions, root,
+ /*fusion_instruction=*/nullptr));
}
void HloComputation::FuseInstructionsInto(
@@ -674,13 +732,34 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
const auto& all = MakeInstructionPostOrder();
- auto result = MakeUnique<HloReachabilityMap>(all);
+ auto result = absl::make_unique<HloReachabilityMap>(all);
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies = channel_dependency_map[hlo->channel_id()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = hlo->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
@@ -829,7 +908,7 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
HloCloneContext* context, const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
- context_ptr = MakeUnique<HloCloneContext>(parent(), suffix);
+ context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
context = context_ptr.get();
}
@@ -901,9 +980,9 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
HloInstruction* HloComputation::GetInstructionWithName(
tensorflow::StringPiece name) {
auto instructions_in_computation = instructions();
- auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) {
- return instr->name() == name;
- });
+ auto it = absl::c_find_if(
+ instructions_in_computation,
+ [&](HloInstruction* instr) { return instr->name() == name; });
return it == instructions_in_computation.end() ? nullptr : *it;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 49ed65910f..faa33f0f90 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -399,6 +399,13 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
+ // Returns a map from channel-id to directed dependencies of the channel
+ // instructions. For send&recv pairs it means the send instruction and for
+ // cross-replica-sum the union of the dependencies for all participating
+ // instructions.
+ std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
+ const;
+
string name_;
int64 unique_id_;
HloInstruction* root_instruction_;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index e4c5470331..f7ed1b0316 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) {
EXPECT_EQ(computation->ToString(options), expected_computation2);
}
-} // namespace
+TEST_F(HloComputationTest, ChannelReachability) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
+ HloComputation::Builder builder("ChannelReachability");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto send =
+ builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto recv =
+ builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build(recv_done));
+ auto reachability = computation->ComputeReachability();
+ EXPECT_TRUE(reachability->IsReachable(param, recv_done));
+ EXPECT_FALSE(reachability->IsReachable(send, recv));
+ EXPECT_FALSE(reachability->IsReachable(send_done, recv));
+}
+
+} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 7229031c0c..6dddda1ca8 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -38,7 +39,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Limit the constant folding to 0 iterations to skip folding loops. This
// retains the behavior from before while loop support in HloEvaluator and may
// be revised.
- auto evaluator = MakeUnique<HloEvaluator>(/*max_loop_iterations=*/0);
+ auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 858992a326..c4e27dc558 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -14,9 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -149,13 +150,13 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
- CHECK(c_all_of(operands, [&](HloInstruction* instr) {
+ CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
return instr->parent() == computation;
}));
std::vector<const Shape*> operand_shapes;
- c_transform(operands, std::back_inserter(operand_shapes),
- [](HloInstruction* instr) { return &instr->shape(); });
+ absl::c_transform(operands, std::back_inserter(operand_shapes),
+ [](HloInstruction* instr) { return &instr->shape(); });
TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
operand_shapes, dimension));
@@ -228,7 +229,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
const Shape& operand_shape = operand->shape();
new_shape_dims.reserve(n + operand_shape.dimensions_size());
new_shape_dims.insert(new_shape_dims.begin(), n, 1);
- c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
+ absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
return MakeReshapeHlo(new_shape_dims, operand);
}
@@ -240,7 +241,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
std::vector<int64> expanded_shape_dim_bounds;
expanded_shape_dim_bounds.reserve(expanded_dims.size() +
operand->shape().dimensions_size() - 1);
- c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
+ absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
std::copy(operand->shape().dimensions().begin() + 1,
operand->shape().dimensions().end(),
std::back_inserter(expanded_shape_dim_bounds));
@@ -251,7 +252,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
ArraySlice<int64> dims_to_elide) {
- CHECK(c_is_sorted(dims_to_elide));
+ CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
// First accumulate in reverse
@@ -268,7 +269,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
}
}
- c_reverse(new_shape_dim_bounds);
+ absl::c_reverse(new_shape_dim_bounds);
Shape output_shape =
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
return MakeReshapeHlo(output_shape, operand);
@@ -276,7 +277,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
StatusOr<HloInstruction*> InsertDegenerateDims(
HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
- CHECK(c_is_sorted(dims_to_insert));
+ CHECK(absl::c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
int64 output_shape_rank =
@@ -318,7 +319,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
*padding_config.add_dimensions() = padding_config_dim;
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
+ HloInstruction::CreateConstant(absl::make_unique<Literal>(
LiteralUtil::Zero(operand->shape().element_type()))));
return MakePadHlo(operand, zero, padding_config);
}
@@ -328,7 +329,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
ArraySlice<int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 60d3e71757..a8de285d16 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -28,7 +28,7 @@ using tensorflow::gtl::ArraySlice;
class HloCreationUtilsTest : public HloTestBase {
protected:
- static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
+ std::unique_ptr<HloModule> CreateModuleWithProgramShape(
PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
ArraySlice<int64> output_shape_dims, HloInstruction** param,
HloComputation** entry_computation) {
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 90fbaa37c5..406d712ec6 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index bbfb0c253f..9b15057929 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -886,7 +886,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis(
+ auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 26e3736e01..3b5cde2996 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 9e096320db..edf0073f30 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,14 +26,14 @@ namespace xla {
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloComputation* computation, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
return std::move(domain_map);
}
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloModule* module, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
for (HloComputation* computation : module->computations()) {
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
}
@@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
// both sides.
for (HloInstruction* operand : instruction->unique_operands()) {
if (IsDomainInstruction(operand)) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(operand);
domain->exit_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
if (instruction == instruction->parent()->root_instruction()) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
@@ -143,7 +144,7 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
HloInstruction* instruction) const {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
domain->instructions = MakeNonDomainInstructions(domain->reach_set);
return std::move(domain);
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 70271be304..7d48be15cf 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -80,7 +81,7 @@ class OpNameMetadata : public DomainMetadata {
explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
std::unique_ptr<DomainMetadata> Clone() const override {
- return MakeUnique<OpNameMetadata>(opname_);
+ return absl::make_unique<OpNameMetadata>(opname_);
}
tensorflow::StringPiece Kind() const override { return KindName(); }
@@ -110,9 +111,9 @@ std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
return nullptr;
}
std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<OpNameMetadata>(operand->metadata().op_name());
+ absl::make_unique<OpNameMetadata>(operand->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<OpNameMetadata>(instruction->metadata().op_name());
+ absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
return HloInstruction::CreateDomain(operand->shape(), operand,
std::move(operand_side_metadata),
std::move(user_side_metadata));
@@ -474,8 +475,8 @@ ENTRY entry {
TEST_F(HloDomainTest, DumpParseNullSharding) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {});
- auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr);
- auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr);
+ auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr);
+ auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr);
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 51353eea6e..35d9e799df 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -23,13 +23,14 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -95,7 +96,7 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
return compare_op(lhs_literal.Get<OperandT>(multi_index),
rhs_literal.Get<OperandT>(multi_index));
@@ -125,7 +126,7 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
@@ -138,44 +139,57 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
- typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
- typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this);
- typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "U16.");
- });
- typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this);
- typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this);
- typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this);
- typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "S16.");
- });
- typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this);
- typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this);
+ typed_visitors_[PRED] =
+ absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
+ typed_visitors_[U8] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
+ typed_visitors_[U16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "U16.");
+ });
+ typed_visitors_[U32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
+ typed_visitors_[U64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
+ typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
+ typed_visitors_[S16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "S16.");
+ });
+ typed_visitors_[S32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
+ typed_visitors_[S64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
typed_visitors_[F16] =
- MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
- typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this);
- typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this);
- typed_visitors_[C64] = MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this);
+ absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
+ typed_visitors_[F32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
+ typed_visitors_[F64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
+ typed_visitors_[C64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
// elementwise computations to be done in F32 and do BF16<->F32 conversion
// around the input and the output of the computations.
typed_visitors_[BF16] =
- MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
-
- typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
- });
- typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
- });
+ absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
+
+ typed_visitors_[TUPLE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
+ });
+ typed_visitors_[OPAQUE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
+ });
}
template <typename LiteralPtr>
@@ -555,43 +569,41 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-// Returns an ShapeUtil::IndexIterationSpace that iterates over the output
-// gather dimensions while keeping the rest of the output dimensions clamped to
-// 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
+// dimensions while keeping the rest of the output dimensions clamped to 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
int64 output_rank = output_shape.dimensions_size();
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_gather_dim =
- !c_binary_search(dim_numbers.output_window_dims(), i);
- index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
- : 1);
+ bool is_output_batch_dim =
+ !absl::c_binary_search(dim_numbers.offset_dims(), i);
+ index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
-// Return an ShapeUtil::IndexIterationSpace that iterates over the output window
+// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
- int64 output_rank, ArraySlice<int64> window_bounds,
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
+ int64 output_rank, ArraySlice<int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
- int64 window_bounds_idx = 0;
+ int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
bool is_output_window_dim =
- c_binary_search(dim_numbers.output_window_dims(), i);
+ absl::c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
- while (c_binary_search(dim_numbers.elided_window_dims(),
- window_bounds_idx)) {
- window_bounds_idx++;
+ while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
+ slice_sizes_idx)) {
+ slice_sizes_idx++;
}
- index_count[i] = window_bounds[window_bounds_idx++];
+ index_count[i] = slice_sizes[slice_sizes_idx++];
}
}
@@ -599,30 +611,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
std::vector<int64>(output_rank, 1)};
}
-// This functor computes the contribution of gather_indices to an input index
+// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
-// out the gather output indices in I and uses them to look up a gather index,
-// G, from the gather indices tensor, and expands G into the input space
-// according to gather_dims_to_operand_dims.
-class OutputGatherIndexToInputIndex {
+// out the batch indices in I and uses them to look up a starting index, G, from
+// the start indices tensor, and expands G into the input space according to
+// start_index_map.
+class OutputBatchIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputGatherIndexToInputIndex(
+ explicit OutputBatchIndexToInputIndex(
const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
- const Shape& output_shape, const Literal* gather_indices)
- : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+ const Shape& output_shape, const Literal* start_indices)
+ : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- output_dim_is_gather_dims_.push_back(
- !c_binary_search(dim_numbers_.output_window_dims(), i));
+ output_dim_is_batch_dims_.push_back(
+ !absl::c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
- std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
- c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+ std::distance(dim_numbers_.start_index_map().begin(),
+ absl::c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
- dim_numbers_.gather_dims_to_operand_dims_size()) {
+ dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
} else {
input_dim_value_to_index_vector_.push_back(
@@ -630,14 +642,14 @@ class OutputGatherIndexToInputIndex {
}
}
- index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+ index_vector_index_.resize(start_indices_.shape().dimensions_size());
input_index_.resize(input_shape.dimensions_size());
int64 index_vector_size =
- gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+ start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
index_vector_.resize(index_vector_size);
}
- // Returns the contribution of gather_indices to the input index corresponding
+ // Returns the contribution of start_indices to the input index corresponding
// to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
@@ -659,7 +671,7 @@ class OutputGatherIndexToInputIndex {
}
private:
- // Propagates the gather index dimensions from the output index into
+ // Propagates the batch dimensions from the output index into
// index_vector_index_ by mutating index_vector_index_ in place. Does not
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
@@ -667,7 +679,7 @@ class OutputGatherIndexToInputIndex {
ArraySlice<int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
- if (!output_dim_is_gather_dims_[i]) {
+ if (!output_dim_is_batch_dims_[i]) {
continue;
}
@@ -679,14 +691,14 @@ class OutputGatherIndexToInputIndex {
}
}
- // Populates index_vector_ by iterating over gather_indices_ according to
+ // Populates index_vector_ by iterating over start_indices_ according to
// index_vector_index_.
Status FetchIndexVector() {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
- TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
- index_vector_index_));
+ TF_ASSIGN_OR_RETURN(index_vector_[i],
+ start_indices_.GetIntegralAsS64(index_vector_index_));
}
return Status::OK();
}
@@ -708,15 +720,15 @@ class OutputGatherIndexToInputIndex {
// PropagateIndexVectorToInputIndex.
std::vector<int64> input_dim_value_to_index_vector_;
- // output_dim_is_gather_dims_[i] is true iff the output index i is a gather
+ // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
// dimension.
- std::vector<bool> output_dim_is_gather_dims_;
+ std::vector<bool> output_dim_is_batch_dims_;
- // The buffer into which we construct an index into gather_indices_ to fetch
+ // The buffer into which we construct an index into start_indices_ to fetch
// the index vector.
std::vector<int64> index_vector_index_;
- // The index vector fetched from gather_indices_.
+ // The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -724,24 +736,23 @@ class OutputGatherIndexToInputIndex {
std::vector<int64> input_index_;
const GatherDimensionNumbers& dim_numbers_;
- const Literal& gather_indices_;
+ const Literal& start_indices_;
};
-// This functor computes the contribution of the window indices in an output
+// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
-// output window indices in I and expands it into a window index into the input
-// shape.
-class OutputWindowIndexToInputIndex {
+// output offset indices in I and expands it into an index into the input shape.
+class OutputOffsetIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputWindowIndexToInputIndex(
+ explicit OutputOffsetIndexToInputIndex(
const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
const Shape& output_shape) {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
@@ -750,7 +761,7 @@ class OutputWindowIndexToInputIndex {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
@@ -808,20 +819,20 @@ class OutputWindowIndexToInputIndex {
// Rehapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
-// there is one) to `reshaped_gather_indices`.
+// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
- int64 index_vector_dim, const Literal& gather_indices,
- std::unique_ptr<Literal>* reshaped_gather_indices) {
- if (gather_indices.shape().dimensions_size() != index_vector_dim) {
- return std::cref(gather_indices);
+ int64 index_vector_dim, const Literal& start_indices,
+ std::unique_ptr<Literal>* reshaped_start_indices) {
+ if (start_indices.shape().dimensions_size() != index_vector_dim) {
+ return std::cref(start_indices);
}
- std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
- gather_indices.shape().dimensions().end());
+ std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
+ start_indices.shape().dimensions().end());
new_shape.push_back(1);
- TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
- gather_indices.Reshape(new_shape));
- return std::cref(**reshaped_gather_indices);
+ TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
+ start_indices.Reshape(new_shape));
+ return std::cref(**reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
@@ -830,34 +841,33 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_gather_indices;
+ std::unique_ptr<Literal> reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
- const Literal& gather_indices,
+ const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
GetEvaluatedLiteralFor(gather->operand(1)),
- &reshaped_gather_indices));
+ &reshaped_start_indices));
// We iterate over the gather dimensions in the output shape in an outer loop
// nest, and iterate over the window dimensions in the output shape in an
// inner loop nest.
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
- IterationSpaceForOutputGatherIndices(shape, dim_numbers);
- ShapeUtil::IndexIterationSpace window_indices_iteration_space =
- IterationSpaceForOutputWindowIndices(
- shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+ ShapeUtil::IndexIterationSpace start_indices_iteration_space =
+ IterationSpaceForOutputBatchIndices(shape, dim_numbers);
+ ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
+ IterationSpaceForOutputOffsetIndices(
+ shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
// Scratch buffers that hold an index in the output shape and the
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
- std::vector<int64> input_gather_index_clamped(
- operand.shape().dimensions_size());
+ std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
- OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+ OutputBatchIndexToInputIndex output_batch_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
- /*output_shape=*/shape, &gather_indices);
- OutputWindowIndexToInputIndex output_window_index_to_input_index(
+ /*output_shape=*/shape, &start_indices);
+ OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape);
@@ -869,29 +879,29 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
ArraySlice<int64> input_window_index,
- output_window_index_to_input_index(output_window_index));
+ output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
for (int i = 0, e = input_gather_index.size(); i < e; i++) {
int64 output_dim =
- output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ output_offset_index_to_input_index.input_dim_value_to_output_index(i);
// If 'output_dim' is -1, it means 'i' is an elided window dim. This means
// we set the iteration index to 0, so for the purpose of the following
// calculations we can consider the output dimension size to be 1.
int64 output_dim_size =
output_dim == -1 ? 1 : shape.dimensions(output_dim);
// Clamp the gather index so that the gather region fits in the operand.
- // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // input_index_clamped[i] = clamp(input_gather_index[i], 0,
// operand_shape.dimensions(i) -
// output_dim_size);
- input_gather_index_clamped[i] =
+ input_index_clamped[i] =
std::min(operand_shape.dimensions(i) - output_dim_size,
std::max(0LL, input_gather_index[i]));
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ input_index[i] = input_index_clamped[i] + input_window_index[i];
DCHECK_GE(input_index[i], 0);
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
@@ -902,18 +912,17 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
auto gather_outer_loop_body =
[&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_gather_index,
- output_gather_index_to_input_index(output_gather_index));
+ TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index,
+ output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, window_indices_iteration_space,
+ shape, offset_indices_iteration_space,
std::bind(gather_inner_loop_body, std::placeholders::_1,
input_gather_index, output_gather_index)));
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, gather_indices_iteration_space, gather_outer_loop_body));
+ shape, start_indices_iteration_space, gather_outer_loop_body));
evaluated_[gather] = std::move(result);
return Status::OK();
}
@@ -960,7 +969,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = MakeUnique<Literal>(
+ evaluated_[get_tuple_element] = absl::make_unique<Literal>(
ShapeUtil::GetTupleElementShape(operand->shape(), index));
return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
/*dest_shape_index=*/{},
@@ -1162,10 +1171,11 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
result_keys_literal->PopulateR1(
tensorflow::gtl::ArraySlice<KeyType>(result_keys));
- auto result_values_literal = MakeUnique<Literal>(values_literal.shape());
+ auto result_values_literal =
+ absl::make_unique<Literal>(values_literal.shape());
result_values_literal->PopulateR1(
tensorflow::gtl::ArraySlice<ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
@@ -1180,8 +1190,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = MakeUnique<Literal>(keys_literal.shape());
- auto values_result_literal = MakeUnique<Literal>(values_literal.shape());
+ auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
+ auto values_result_literal =
+ absl::make_unique<Literal>(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index a4c37ef328..7588916de5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -226,7 +226,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
ShapeUtil::HumanString(operand->shape()).c_str());
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 3ac6d68df3..4b8e6260ac 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -52,7 +53,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
HloEvaluatorTest() : use_bfloat16_(GetParam()) {
- evaluator_ = MakeUnique<HloEvaluator>();
+ evaluator_ = absl::make_unique<HloEvaluator>();
}
std::unique_ptr<Literal> Evaluate(
@@ -523,7 +524,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
(*expected_array)(1, 0, 0, 0) = 1.0f;
(*expected_array)(1, 2, 0, 0) = 2.0f;
@@ -547,7 +548,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -568,7 +569,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
std::unique_ptr<Literal> result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
- auto expected_array = MakeUnique<Array2D<float>>(1, 5);
+ auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
(*expected_array)(0, 0) = 7.0f;
(*expected_array)(0, 1) = 2.718f;
(*expected_array)(0, 2) = 2.718f;
@@ -588,7 +589,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -612,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected_array = MakeUnique<Array2D<float>>(0, 9);
+ auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -628,7 +629,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// { 3 },
// { 4 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 1);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 1);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -679,7 +680,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -710,7 +711,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 3);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 3);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -722,7 +723,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -1297,7 +1298,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1339,7 +1340,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1390,7 +1391,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1511,7 +1512,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
// { 9, 10, 11, 12, 13 },
// { 17, 18, 19, 20, 21 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(3, 5);
+ auto operand_array = absl::make_unique<Array2D<float>>(3, 5);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1544,7 +1545,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1580,7 +1581,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1614,7 +1615,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1651,7 +1652,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal2 =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1687,7 +1688,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
@@ -1826,21 +1827,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1851,21 +1851,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1876,22 +1875,22 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
*LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1902,11 +1901,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1914,11 +1913,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest,
@@ -1930,11 +1929,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
@@ -1942,11 +1941,11 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1957,21 +1956,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1982,21 +1980,21 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2007,20 +2005,19 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2031,21 +2028,21 @@ ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1}
+ slice_sizes={1}
}
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ *Evaluate({operand.get(), start_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2517,6 +2514,31 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) {
std::move(rhs));
}
+TEST_P(HloEvaluatorTest, Bf16Reduction) {
+ const string hlo_text = R"(
+HloModule Bf16Reduction
+
+add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
+ lhs = bf16[] parameter(0)
+ rhs = bf16[] parameter(1)
+ ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
+}
+
+ENTRY main {
+ arg0 = bf16[4]{0} parameter(0)
+ init = bf16[] constant(0)
+ ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+ {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 084b49b478..83d7b404f0 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -86,6 +88,29 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// of this class.
template <typename ReturnT, typename ElementwiseT = ReturnT>
class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
+ private:
+ // Get the value in the given literal static_cast as a double.
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ double GetAsDouble(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> input_index) {
+ return static_cast<double>(literal.Get<NativeT>(input_index));
+ }
+
+ // Specialization for complex types. In this case it is not possible to
+ // static_cast value to a double so just CHECK fail. This method is not used
+ // at run-time, but must be available at compile-time to keep the compiler
+ // happy.
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ double GetAsDouble(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> input_index) {
+ LOG(FATAL) << "Trying to get complex literal as double: "
+ << literal.ToString();
+ }
+
public:
explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}
@@ -873,7 +898,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
@@ -1030,7 +1055,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return static_cast<ReturnT>(result_val);
};
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
@@ -1104,7 +1129,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
}
- auto result = MakeUnique<Literal>(dot->shape());
+ auto result = absl::make_unique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
@@ -1153,7 +1178,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = MakeUnique<Literal>(pad->shape());
+ auto result = absl::make_unique<Literal>(pad->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
return scalar;
@@ -1318,7 +1343,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = MakeUnique<Literal>(map->shape());
+ auto result = absl::make_unique<Literal>(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
@@ -1432,7 +1457,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
result_literal->PopulateR1(
tensorflow::gtl::ArraySlice<ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
@@ -1444,7 +1469,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
@@ -1518,7 +1543,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce->shape());
+ auto result = absl::make_unique<Literal>(reduce->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
@@ -1536,7 +1561,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
IsScalarAdd(function)) {
double computed_result = 0;
auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
- computed_result += arg_literal.Get<float>(input_index);
+ computed_result += GetAsDouble<ReturnT>(arg_literal, input_index);
return true;
};
ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
@@ -1599,7 +1624,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = MakeUnique<Literal>(select_and_scatter->shape());
+ auto result = absl::make_unique<Literal>(select_and_scatter->shape());
// Initialize result array with the init value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
@@ -1735,7 +1760,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce_window->shape());
+ auto result = absl::make_unique<Literal>(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> output_index) {
@@ -1802,7 +1827,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_scatter_dim =
- !c_binary_search(dim_numbers.update_window_dims(), i);
+ !absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_scatter_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1821,7 +1846,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_window_dim =
- c_binary_search(dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1848,7 +1873,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
: dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) {
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
update_dim_is_scatter_dims_.push_back(
- !c_binary_search(dim_numbers_.update_window_dims(), i));
+ !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
@@ -1978,7 +2003,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> window_index_to_update_index;
int64 update_index_count = 0;
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.update_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
window_index_to_update_index.push_back(update_index_count++);
} else {
update_index_count++;
@@ -1987,7 +2012,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_dim_value_to_update_index_.push_back(-1);
} else {
input_dim_value_to_update_index_.push_back(
@@ -2388,7 +2413,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::is_same<NativeT, int32>::value ||
std::is_same<NativeT, uint32>::value>::type* = nullptr>
Status HandleIota(HloInstruction* iota) {
- auto result = MakeUnique<Literal>(iota->shape());
+ auto result = absl::make_unique<Literal>(iota->shape());
auto data = result->data<ReturnT>();
std::iota(data.begin(), data.end(), 0);
parent_->evaluated_[iota] = std::move(result);
@@ -2470,7 +2495,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = MakeUnique<Literal>(result_shape);
+ auto result = absl::make_unique<Literal>(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
@@ -2556,7 +2581,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
@@ -2594,7 +2619,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index c3ccbf0f0c..de3d7a1677 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
@@ -49,7 +51,7 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
size_t profile_counters_size = hlo_profile_index_map.total_count();
std::unique_ptr<HloProfilePrinterData> profile_printer_data =
- MakeUnique<HloProfilePrinterData>();
+ absl::make_unique<HloProfilePrinterData>();
profile_printer_data->set_profile_counters_size(profile_counters_size);
profile_printer_data->mutable_computation_infos()->Reserve(
hlo_profile_index_map.computation_count());
@@ -67,11 +69,11 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
// The profile indices were computed deterministically in
// HloProfileIndexMap::HloProfileIndexMap.
- c_sort(computation_and_profile_idx_list,
- [](const std::pair<const HloComputation*, int64>& left,
- const std::pair<const HloComputation*, int64>& right) {
- return left.second < right.second;
- });
+ absl::c_sort(computation_and_profile_idx_list,
+ [](const std::pair<const HloComputation*, int64>& left,
+ const std::pair<const HloComputation*, int64>& right) {
+ return left.second < right.second;
+ });
for (const auto& pair : computation_and_profile_idx_list) {
CHECK_LT(pair.second, profile_counters_size);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8690f2cdaa..e3d6b2e753 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -21,10 +21,11 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -224,7 +225,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal));
} else {
- instruction = MakeUnique<HloConstantInstruction>(proto.shape());
+ instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
}
break;
}
@@ -281,27 +282,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kInfeed: {
const Shape& data_shape =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
- if (proto.operand_ids_size() == 0) {
- // TODO(b/80000000): Remove this when all uses of infeed are
- // converted to take tokens.
- instruction = CreateInfeed(data_shape, proto.infeed_config());
- } else {
- CHECK_EQ(proto.operand_ids_size(), 1);
- instruction =
- CreateInfeed(data_shape, operands(0), proto.infeed_config());
- }
+ TF_RET_CHECK(proto.operand_ids_size() == 1);
+ instruction =
+ CreateInfeed(data_shape, operands(0), proto.infeed_config());
} break;
case HloOpcode::kOutfeed:
- if (proto.operand_ids_size() == 1) {
- // TODO(b/80000000): Remove this when all uses of outfeed are
- // converted to take tokens.
- instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
- proto.outfeed_config());
- } else {
- CHECK_EQ(proto.operand_ids_size(), 2);
- instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
- operands(1), proto.outfeed_config());
- }
+ TF_RET_CHECK(proto.operand_ids_size() == 2);
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ operands(1), proto.outfeed_config());
break;
case HloOpcode::kCrossReplicaSum: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
@@ -335,9 +323,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
- instruction =
- CreateConvolve(proto.shape(), operands(0), operands(1),
- proto.window(), proto.convolution_dimension_numbers());
+ instruction = CreateConvolve(
+ proto.shape(), operands(0), operands(1), proto.window(),
+ proto.convolution_dimension_numbers(),
+ std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
break;
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
@@ -391,7 +380,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "DynamicSlice instruction should have 2 operands but sees "
<< proto.operand_ids_size();
std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
- c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
+ absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
slice_sizes);
break;
@@ -403,14 +392,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_gather_dimension_numbers())
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
- MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- std::vector<int64> gather_window_bounds;
- for (int64 bound : proto.gather_window_bounds()) {
- gather_window_bounds.push_back(bound);
+ absl::make_unique<GatherDimensionNumbers>(
+ proto.gather_dimension_numbers());
+ std::vector<int64> gather_slice_sizes;
+ for (int64 bound : proto.gather_slice_sizes()) {
+ gather_slice_sizes.push_back(bound);
}
- instruction =
- CreateGather(proto.shape(), operands(0), operands(1),
- *gather_dimension_numbers, gather_window_bounds);
+ instruction = CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_slice_sizes);
break;
}
case HloOpcode::kScatter: {
@@ -422,15 +411,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Scatter instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
- proto.scatter_dimension_numbers());
+ auto scatter_dimension_numbers =
+ absl::make_unique<ScatterDimensionNumbers>(
+ proto.scatter_dimension_numbers());
instruction =
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
break;
}
default: {
- instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
+ instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
<< "No instruction with id " << operand_id;
@@ -461,7 +451,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
if (proto.has_dot_dimension_numbers()) {
instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
+ absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
}
if (proto.has_sharding()) {
@@ -475,34 +465,36 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
- return MakeUnique<HloParameterInstruction>(parameter_number, shape, name);
+ return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
+ name);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
const string& tag, HloInstruction* operand) {
- return MakeUnique<HloTraceInstruction>(tag, operand);
+ return absl::make_unique<HloTraceInstruction>(tag, operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
std::unique_ptr<Literal> literal) {
- return MakeUnique<HloConstantInstruction>(std::move(literal));
+ return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
const Shape& shape) {
- return WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
+ return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
- return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index);
+ return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
+ index);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
- return MakeUnique<HloRngInstruction>(shape, distribution, parameters);
+ return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
@@ -512,7 +504,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
// It is impossible to copy an opaque shape, we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
}
- auto instruction = WrapUnique(new HloInstruction(opcode, shape));
+ auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -617,31 +609,33 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation) {
- return MakeUnique<HloMapInstruction>(shape, operands, map_computation);
+ return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return MakeUnique<HloConvolutionInstruction>(shape, lhs, rhs, window,
- dimension_numbers);
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
+ return absl::make_unique<HloConvolutionInstruction>(
+ shape, lhs, rhs, window, dimension_numbers, feature_group_count);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length) {
- return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length);
+ return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
+ fft_length);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(dimension_numbers);
+ absl::make_unique<DotDimensionNumbers>(dimension_numbers);
return instruction;
}
@@ -650,10 +644,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
+ instruction->dot_dimension_numbers_ =
+ absl::make_unique<DotDimensionNumbers>();
instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
return instruction;
@@ -664,7 +660,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
const int exponent_bits,
const int mantissa_bits) {
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, operand, exponent_bits, mantissa_bits);
}
@@ -675,7 +671,7 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
const tensorflow::gtl::optional<int64>& all_reduce_id) {
- return MakeUnique<HloAllReduceInstruction>(
+ return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_group_ids, barrier,
all_reduce_id);
}
@@ -684,40 +680,29 @@ HloInstruction::CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
const std::vector<ReplicaGroup>& replica_groups,
tensorflow::StringPiece barrier) {
- return MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups,
- barrier);
+ return absl::make_unique<HloAllToAllInstruction>(shape, operands,
+ replica_groups, barrier);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config) {
- return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config);
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
- const Shape& infeed_shape, const string& config) {
- return MakeUnique<HloInfeedInstruction>(infeed_shape, config);
+ return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
+ config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
- token_operand, outfeed_config);
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
- const Shape& outfeed_shape, HloInstruction* operand,
- tensorflow::StringPiece outfeed_config) {
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
- outfeed_config);
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape, operand, token_operand, outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloSendInstruction>(operand, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
@@ -725,14 +710,15 @@ HloInstruction::CreateCrossReplicaSum(
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer);
+ return absl::make_unique<HloSendDoneInstruction>(send_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloRecvInstruction>(shape, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
@@ -740,19 +726,20 @@ HloInstruction::CreateCrossReplicaSum(
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer);
+ return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
+ return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK(!operands.empty());
- auto instruction = WrapUnique(
+ auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
instruction->AppendOperand(operand);
@@ -761,14 +748,15 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
- return WrapUnique(
+ return absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
// Body comes before condition computation in the vector.
instruction->called_computations_.push_back(body);
@@ -781,7 +769,7 @@ HloInstruction::CreateCrossReplicaSum(
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
@@ -798,15 +786,15 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return MakeUnique<HloSliceInstruction>(shape, operand, start_indices,
- limit_indices, strides);
+ return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
+ limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return MakeUnique<HloDynamicSliceInstruction>(shape, operand, start_indices,
- slice_sizes);
+ return absl::make_unique<HloDynamicSliceInstruction>(
+ shape, operand, start_indices, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -814,8 +802,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
+ auto instruction = absl::WrapUnique(
+ new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(update);
instruction->AppendOperand(start_indices);
@@ -825,12 +813,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
int64 dimension) {
- return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension);
+ return absl::make_unique<HloConcatenateInstruction>(shape, operands,
+ dimension);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
const Shape& shape, HloInstruction* operand) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -839,7 +829,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -848,7 +838,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- auto instruction = WrapUnique(new HloReduceInstruction(
+ auto instruction = absl::WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
return std::move(instruction);
}
@@ -862,15 +852,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
all_args.reserve(operands.size() * 2);
all_args.insert(all_args.end(), operands.begin(), operands.end());
all_args.insert(all_args.end(), init_values.begin(), init_values.end());
- return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
- reduce_computation);
+ return absl::make_unique<HloReduceInstruction>(
+ shape, all_args, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation) {
- return MakeUnique<HloReduceWindowInstruction>(shape, operand, init_value,
- window, reduce_computation);
+ return absl::make_unique<HloReduceWindowInstruction>(
+ shape, operand, init_value, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -879,7 +869,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* scale,
HloInstruction* offset, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, operand, scale, offset, epsilon, feature_index);
}
@@ -888,7 +878,7 @@ HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
@@ -898,9 +888,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* variance,
HloInstruction* grad_output, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean,
- variance, grad_output, epsilon,
- feature_index);
+ return absl::make_unique<HloBatchNormGradInstruction>(
+ shape, operand, scale, mean, variance, grad_output, epsilon,
+ feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -908,15 +898,15 @@ HloInstruction::CreateSelectAndScatter(
const Shape& shape, HloInstruction* operand, HloComputation* select,
const Window& window, HloInstruction* source, HloInstruction* init_value,
HloComputation* scatter) {
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, operand, select, window, source, init_value, scatter);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return MakeUnique<HloBroadcastInstruction>(shape, operand,
- broadcast_dimensions);
+ return absl::make_unique<HloBroadcastInstruction>(shape, operand,
+ broadcast_dimensions);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -974,8 +964,8 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
const PaddingConfig& padding_config) {
- return MakeUnique<HloPadInstruction>(shape, operand, padding_value,
- padding_config);
+ return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
+ padding_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
@@ -984,7 +974,8 @@ HloInstruction::CreateBroadcastSequence(
ShapeUtil::ElementsIn(operand->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operand->shape());
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -992,26 +983,27 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
+ return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
HloInstruction* values) {
- return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
+ fused_root);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* fusion_computation) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, operands,
- fusion_computation);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
+ fusion_computation);
}
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
@@ -1069,7 +1061,7 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -1080,15 +1072,15 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) {
- return MakeUnique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
+ return absl::make_unique<HloCustomCallInstruction>(shape, operands,
+ custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
- return MakeUnique<HloHostComputeInstruction>(shape, operands, channel_name,
- cost_estimate_ns);
+ return absl::make_unique<HloHostComputeInstruction>(
+ shape, operands, channel_name, cost_estimate_ns);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -1102,11 +1094,11 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
- gather_dim_numbers, window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return absl::make_unique<HloGatherInstruction>(
+ shape, operand, start_indices, gather_dim_numbers, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
@@ -1114,16 +1106,17 @@ bool HloInstruction::HasSideEffect() const {
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers) {
- return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
- updates, update_computation,
- scatter_dim_numbers);
+ return absl::make_unique<HloScatterInstruction>(
+ shape, operand, scatter_indices, updates, update_computation,
+ scatter_dim_numbers);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
instruction->operand_side_metadata_ = std::move(operand_side_metadata);
instruction->user_side_metadata_ = std::move(user_side_metadata);
instruction->AppendOperand(operand);
@@ -3206,6 +3199,10 @@ void HloInstruction::set_convolution_dimension_numbers(
}
}
+int64 HloInstruction::feature_group_count() const {
+ return Cast<HloConvolutionInstruction>(this)->feature_group_count();
+}
+
HloComputation* HloInstruction::select() const {
return Cast<HloSelectAndScatterInstruction>(this)->select();
}
@@ -3246,9 +3243,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
- const {
- return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const {
+ return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 3c575ae6ea..30dbabfced 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -32,6 +32,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -402,7 +403,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -486,11 +488,6 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config);
- // Overload which does not require a token.
- // TODO(b/80000000): Remove this overload when all uses of infeed are
- // converted to take tokens.
- static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& infeed_shape,
- const string& config);
// Creates an outfeed instruction, which outputs data. outfeed_shape is the
// shape of the data being outfed *not* the shape of the outfeed instruction
@@ -498,12 +495,6 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
- // Overload which does not require a token.
- // TODO(b/80000000): Remove this overload when all uses of outfeed are
- // converted to take tokens.
- static std::unique_ptr<HloInstruction> CreateOutfeed(
- const Shape& outfeed_shape, HloInstruction* operand,
- tensorflow::StringPiece outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
@@ -677,9 +668,9 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -1062,7 +1053,7 @@ class HloInstruction {
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
- sharding_ = MakeUnique<HloSharding>(sharding);
+ sharding_ = absl::make_unique<HloSharding>(sharding);
}
void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
@@ -1466,6 +1457,10 @@ class HloInstruction {
void set_convolution_dimension_numbers(
const ConvolutionDimensionNumbers& dnums);
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count() const;
+
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* select() const;
@@ -1495,8 +1490,8 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
- // Delegates to HloGatherInstruction::gather_window_bounds.
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloGatherInstruction::gather_slice_sizes.
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8a694dde80..504b13043f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1355,7 +1355,7 @@ TEST_F(HloInstructionTest, Stringification) {
TEST_F(HloInstructionTest, StringifyGather_0) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
@@ -1363,19 +1363,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1383,15 +1382,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+ "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyGather_1) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
@@ -1399,19 +1398,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1419,10 +1417,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=2, window_bounds={30,29,28,27,26}");
+ "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyScatter) {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 1de5032670..79a5e7481d 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <deque>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -89,7 +91,7 @@ HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
feature_index());
}
@@ -111,7 +113,7 @@ HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
@@ -133,7 +135,7 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormGradInstruction>(
+ return absl::make_unique<HloBatchNormGradInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
@@ -175,8 +177,8 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_,
- fft_length_);
+ return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
+ fft_length_);
}
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
@@ -230,8 +232,8 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
- channel_id(), is_host_transfer());
+ return absl::make_unique<HloSendInstruction>(
+ new_operands[0], new_operands[1], channel_id(), is_host_transfer());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
@@ -248,7 +250,7 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendDoneInstruction>(
+ return absl::make_unique<HloSendDoneInstruction>(
Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
}
@@ -269,7 +271,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvInstruction>(
+ return absl::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
is_host_transfer());
}
@@ -291,7 +293,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvDoneInstruction>(
+ return absl::make_unique<HloRecvDoneInstruction>(
Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
@@ -354,7 +356,7 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllReduceInstruction>(
+ return absl::make_unique<HloAllReduceInstruction>(
shape, new_operands, to_apply(), replica_group_ids(),
cross_replica_sum_barrier(), all_reduce_id());
}
@@ -390,7 +392,7 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllToAllInstruction>(
+ return absl::make_unique<HloAllToAllInstruction>(
shape, new_operands, replica_groups(), cross_replica_sum_barrier());
}
@@ -454,8 +456,8 @@ std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
+ dimensions());
}
HloConcatenateInstruction::HloConcatenateInstruction(
@@ -494,8 +496,8 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
- dimensions(0));
+ return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
+ dimensions(0));
}
HloReduceInstruction::HloReduceInstruction(
@@ -539,8 +541,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
- to_apply());
+ return absl::make_unique<HloReduceInstruction>(shape, new_operands,
+ dimensions(), to_apply());
}
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
@@ -580,7 +582,8 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
- return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
+ values);
}
HloTransposeInstruction::HloTransposeInstruction(
@@ -633,8 +636,8 @@ HloTransposeInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloTransposeInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
+ dimensions());
}
HloBroadcastInstruction::HloBroadcastInstruction(
@@ -672,8 +675,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
+ dimensions());
}
HloMapInstruction::HloMapInstruction(
@@ -730,7 +733,7 @@ std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply());
+ return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
}
HloSliceInstruction::HloSliceInstruction(
@@ -792,8 +795,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
- slice_limits_, slice_strides_);
+ return absl::make_unique<HloSliceInstruction>(
+ shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
@@ -845,7 +848,7 @@ HloConstantInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
+ return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
}
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -1339,8 +1342,8 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation = module->AddEmbeddedComputation(
fused_instructions_computation()->Clone("clone", context));
}
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind(), new_operands,
- new_fused_computation);
+ return absl::make_unique<HloFusionInstruction>(
+ shape, fusion_kind(), new_operands, new_fused_computation);
}
Status HloFusionInstruction::DeduplicateFusionOperands() {
@@ -1399,7 +1402,8 @@ std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloRngInstruction>(shape, distribution_, new_operands);
+ return absl::make_unique<HloRngInstruction>(shape, distribution_,
+ new_operands);
}
HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
@@ -1435,7 +1439,8 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name());
+ return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
+ name());
}
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
@@ -1471,8 +1476,8 @@ HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloGetTupleElementInstruction>(shape, new_operands[0],
- tuple_index());
+ return absl::make_unique<HloGetTupleElementInstruction>(
+ shape, new_operands[0], tuple_index());
}
HloReducePrecisionInstruction::HloReducePrecisionInstruction(
@@ -1514,7 +1519,7 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, new_operands[0], exponent_bits(), mantissa_bits());
}
@@ -1528,13 +1533,6 @@ HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
AppendOperand(token_operand);
}
-HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
- const string& config)
- : HloInstruction(HloOpcode::kInfeed,
- ShapeUtil::MakeTupleShape(
- {infeed_shape, ShapeUtil::MakeTokenShape()})),
- infeed_config_(config) {}
-
HloInstructionProto HloInfeedInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_infeed_config(infeed_config_);
@@ -1561,13 +1559,9 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- if (new_operands.empty()) {
- return MakeUnique<HloInfeedInstruction>(infeed_shape(), infeed_config());
- } else {
- CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
- infeed_config());
- }
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloInfeedInstruction>(
+ infeed_shape(), new_operands[0], infeed_config());
}
HloOutfeedInstruction::HloOutfeedInstruction(
@@ -1583,18 +1577,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(
AppendOperand(token_operand);
}
-HloOutfeedInstruction::HloOutfeedInstruction(
- const Shape& outfeed_shape, HloInstruction* operand,
- tensorflow::StringPiece outfeed_config)
- : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
- outfeed_shape_(outfeed_shape),
- outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
- CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
- << "Outfeed shape " << outfeed_shape
- << " must be compatible with operand shape " << operand->shape();
- AppendOperand(operand);
-}
-
HloInstructionProto HloOutfeedInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_outfeed_config(outfeed_config());
@@ -1622,22 +1604,19 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- if (new_operands.size() == 1) {
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
- outfeed_config());
- } else {
- CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
- new_operands[1], outfeed_config());
- }
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
}
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers)
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count)
: HloInstruction(HloOpcode::kConvolution, shape),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ feature_group_count_(feature_group_count) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1675,6 +1654,7 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
return extra;
}
@@ -1696,9 +1676,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0],
- new_operands[1], window(),
- convolution_dimension_numbers_);
+ return absl::make_unique<HloConvolutionInstruction>(
+ shape, new_operands[0], new_operands[1], window(),
+ convolution_dimension_numbers_, feature_group_count_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -1741,7 +1721,7 @@ HloReduceWindowInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceWindowInstruction>(
+ return absl::make_unique<HloReduceWindowInstruction>(
shape, new_operands[0], new_operands[1], window(), to_apply());
}
@@ -1790,7 +1770,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, new_operands[0], select(), window(), new_operands[1],
new_operands[2], scatter());
}
@@ -1865,8 +1845,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- auto cloned = MakeUnique<HloCustomCallInstruction>(shape, new_operands,
- custom_call_target());
+ auto cloned = absl::make_unique<HloCustomCallInstruction>(
+ shape, new_operands, custom_call_target());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
@@ -1907,7 +1887,7 @@ HloHostComputeInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloHostComputeInstruction>(
+ return absl::make_unique<HloHostComputeInstruction>(
shape, new_operands, channel_name_, cost_estimate_ns_);
}
@@ -1945,8 +1925,8 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloPadInstruction>(shape, new_operands[0], new_operands[1],
- padding_config_);
+ return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
+ new_operands[1], padding_config_);
}
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
@@ -1985,56 +1965,55 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloDynamicSliceInstruction>(
+ return absl::make_unique<HloDynamicSliceInstruction>(
shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
}
HloGatherInstruction::HloGatherInstruction(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds)
+ tensorflow::gtl::ArraySlice<int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
- AppendOperand(gather_indices);
+ AppendOperand(start_indices);
gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+ absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
+ absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
}
string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string offset_dims =
+ StrCat("offset_dims={",
+ Join(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims =
+ StrCat("collapsed_slice_dims={",
+ Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ string start_index_map =
+ StrCat("start_index_map={",
+ Join(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
+ {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
+ for (int64 output_window_dim : offset_dims) {
+ gather_dim_numbers.add_offset_dims(output_window_dim);
}
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ for (int64 elided_window_dim : collapsed_slice_dims) {
+ gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
}
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ for (int64 gather_dim_to_input_dim : start_index_map) {
+ gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
}
gather_dim_numbers.set_index_vector_dim(index_vector_dim);
@@ -2044,8 +2023,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
HloInstructionProto HloGatherInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
+ for (int64 bound : gather_slice_sizes()) {
+ proto.add_gather_slice_sizes(bound);
}
return proto;
}
@@ -2053,7 +2032,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+ StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2064,7 +2043,7 @@ bool HloGatherInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(
gather_dimension_numbers(),
casted_other.gather_dimension_numbers()) &&
- gather_window_bounds() == casted_other.gather_window_bounds();
+ gather_slice_sizes() == casted_other.gather_slice_sizes();
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
@@ -2072,9 +2051,9 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloGatherInstruction>(
+ return absl::make_unique<HloGatherInstruction>(
shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
- gather_window_bounds());
+ gather_slice_sizes());
}
HloScatterInstruction::HloScatterInstruction(
@@ -2088,7 +2067,7 @@ HloScatterInstruction::HloScatterInstruction(
AppendOperand(updates);
AppendComputation(update_computation);
scatter_dimension_numbers_ =
- MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+ absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
}
string HloScatterInstruction::ScatterDimensionNumbersToString() const {
@@ -2159,7 +2138,7 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloScatterInstruction>(
+ return absl::make_unique<HloScatterInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
scatter_dimension_numbers());
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 9586ad6673..19b69c2171 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@@ -883,10 +884,6 @@ class HloInfeedInstruction : public HloInstruction {
explicit HloInfeedInstruction(const Shape& infeed_shape,
HloInstruction* token_operand,
const string& config);
- // TODO(b/80000000): Remove this constructor when all uses of infeed are
- // converted to take tokens.
- explicit HloInfeedInstruction(const Shape& infeed_shape,
- const string& config);
// Returns the infeed configuration string. The infeed configuration includes
// any metadata needed for the backend compiler (e.g., infeed buffer address)
// and is target-dependent.
@@ -925,12 +922,6 @@ class HloOutfeedInstruction : public HloInstruction {
HloInstruction* operand,
HloInstruction* token_operand,
tensorflow::StringPiece outfeed_config);
- // TODO(b/80000000): Remove this constructor when all uses of outfeed are
- // converted to take tokens.
- explicit HloOutfeedInstruction(const Shape& outfeed_shape,
- HloInstruction* operand,
- tensorflow::StringPiece outfeed_config);
-
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
@@ -965,7 +956,8 @@ class HloConvolutionInstruction : public HloInstruction {
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -975,6 +967,9 @@ class HloConvolutionInstruction : public HloInstruction {
const ConvolutionDimensionNumbers& dnums) {
convolution_dimension_numbers_ = dnums;
}
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count() const { return feature_group_count_; }
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -994,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction {
Window window_;
// Describes the dimension numbers used for a convolution.
ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count_;
};
class HloReduceWindowInstruction : public HloInstruction {
@@ -1083,7 +1081,7 @@ class HloCustomCallInstruction : public HloInstruction {
}
void set_window(const Window& window) override {
- window_ = MakeUnique<Window>(window);
+ window_ = absl::make_unique<Window>(window);
}
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -1094,7 +1092,7 @@ class HloCustomCallInstruction : public HloInstruction {
void set_convolution_dimension_numbers(
const ConvolutionDimensionNumbers& dnums) {
convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(dnums);
+ absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
// Returns a serialized representation of this instruction.
@@ -1215,15 +1213,15 @@ class HloGatherInstruction : public HloInstruction {
public:
explicit HloGatherInstruction(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- return gather_window_bounds_;
+ tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const {
+ return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;
@@ -1232,9 +1230,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ tensorflow::gtl::ArraySlice<int64> offset_dims,
+ tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
+ tensorflow::gtl::ArraySlice<int64> start_index_map,
int64 index_vector_dim);
private:
@@ -1250,7 +1248,7 @@ class HloGatherInstruction : public HloInstruction {
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
+ std::vector<int64> gather_slice_sizes_;
};
class HloScatterInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 43c41ece6e..18f17b75ae 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <deque>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -296,7 +296,7 @@ StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module));
+ auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module));
liveness_analysis->RunAnalysis();
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 55ff073d3f..d60b76d63f 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -22,8 +22,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -274,7 +275,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(entry != nullptr);
- auto module = MakeUnique<HloModule>(proto.name(), module_config);
+ auto module = absl::make_unique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
@@ -507,7 +508,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
- auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
+ auto module = absl::make_unique<HloModule>(name_ + "-" + suffix, config_);
HloCloneContext context(module.get(), suffix);
auto cloned_computation = entry_computation_->Clone(suffix, &context);
@@ -538,9 +539,9 @@ uint64 HloModule::RandomNew64() const {
HloComputation* HloModule::GetComputationWithName(
tensorflow::StringPiece name) {
auto computations_in_module = computations();
- auto it = c_find_if(computations_in_module, [&](HloComputation* computation) {
- return computation->name() == name;
- });
+ auto it = absl::c_find_if(
+ computations_in_module,
+ [&](HloComputation* computation) { return computation->name() == name; });
return it == computations_in_module.end() ? nullptr : *it;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 07a8c798db..f9708283eb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <atomic>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/strings/str_util.h"
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 10bf9ffd6c..cd10913763 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
- auto metadata = MakeUnique<HloModuleGroupMetadata>(modules);
+ auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
TF_RETURN_IF_ERROR(metadata->Build());
return std::move(metadata);
}
@@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
return channels_[channel_id_map_.at(channel_id)];
}
+bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
+ return channel_id_map_.find(channel_id) != channel_id_map_.end();
+}
+
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
@@ -383,7 +387,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- tensorflow::MakeUnique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::unordered_set<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
companion_set->insert(instruction1);
companion_set->insert(instruction2);
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 1b256cd00e..924c8fda71 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -125,6 +125,9 @@ class HloModuleGroupMetadata {
// Returns the Channel instance for the given channel id.
const Channel& GetChannel(int64 channel_id) const;
+ // Returns if the given channel id exists in metadata.
+ bool HasChannel(int64 channel_id) const;
+
// Returns the all-reduce instructions with the same all_reduce_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
int64 all_reduce_id) const;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 0dc5676148..1a4da388e4 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
add_unique_predecessor(control_predecessor);
}
}
- if (instruction->opcode() == HloOpcode::kRecvDone) {
+ if (instruction->opcode() == HloOpcode::kRecvDone &&
+ !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_predecessor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
add_unique_successor(control_successor);
}
}
- if (instruction->opcode() == HloOpcode::kRecv) {
+ if (instruction->opcode() == HloOpcode::kRecv &&
+ !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_successor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -332,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability(
TF_RETURN_IF_ERROR(
VisitTopologicalOrder(&visit_states, visit_function, root));
}
- auto reachability = MakeUnique<HloReachabilityMap>(post_order);
+ auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
for (HloInstruction* hlo : post_order) {
reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 236f450086..209ad5e58c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index ec279867e5..0e0d96ab09 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -156,7 +156,7 @@ enum HloOpcodeProperty {
// Returns a string representation of the opcode.
string HloOpcodeString(HloOpcode opcode);
-// Returns a string representation of the opcode.
+// Retrieves the opcode enum by name if the opcode exists.
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4b3cd99dc0..3768da8a73 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -304,7 +306,7 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = MakeUnique<HloModule>(name, config_);
+ module_ = absl::make_unique<HloModule>(name, config_);
return ParseComputations();
}
@@ -357,7 +359,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = MakeUnique<HloComputation::Builder>(name);
+ auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -635,12 +637,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
- c_transform(*tmp_groups, std::back_inserter(replica_groups),
- [](const std::vector<int64>& ids) {
- ReplicaGroup group;
- *group.mutable_replica_ids() = {ids.begin(), ids.end()};
- return group;
- });
+ absl::c_transform(
+ *tmp_groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
}
instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
shape, operands, replica_groups, barrier ? *barrier : ""));
@@ -825,9 +828,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -835,8 +841,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!window) {
window.emplace();
}
+ if (!feature_group_count) {
+ feature_group_count = 1;
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
+ feature_group_count.value()));
break;
}
case HloOpcode::kFft: {
@@ -1073,7 +1083,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
return false;
}
// We need to know the infeed data shape to construct the infeed
@@ -1085,41 +1096,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return Error(lexer_.GetLoc(),
"infeed must have a non-empty tuple shape");
}
-
- if (operands.empty()) {
- // TODO(b/80000000): Remove this when all uses of infeed are
- // converted to take tokens.
- instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
- ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : ""));
- } else if (operands.size() == 1) {
- instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
- ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
- config ? *config : ""));
- } else {
- return Error(lexer_.GetLoc(),
- "infeed must have exactly zero or one operands");
- }
+ instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
+ config ? *config : ""));
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
return false;
}
- if (operands.size() == 1) {
- // TODO(b/80000000): Remove this when all uses of outfeed are
- // converted to take tokens.
- instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
- operands[0]->shape(), operands[0], config ? *config : ""));
- } else if (operands.size() == 2) {
- instruction = builder->AddInstruction(
- HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
- operands[1], config ? *config : ""));
- } else {
- return Error(lexer_.GetLoc(),
- "outfeed must have exactly one or two operands");
- }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
+ operands[1], config ? *config : ""));
break;
}
case HloOpcode::kRng: {
@@ -1245,22 +1236,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGather: {
- optional<std::vector<tensorflow::int64>> output_window_dims;
- attrs["output_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<tensorflow::int64>> elided_window_dims;
- attrs["elided_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
- attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
- AttrTy::kBracedInt64List,
- &gather_dims_to_operand_dims};
+ optional<std::vector<tensorflow::int64>> offset_dims;
+ attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &offset_dims};
+ optional<std::vector<tensorflow::int64>> collapsed_slice_dims;
+ attrs["collapsed_slice_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
+ optional<std::vector<tensorflow::int64>> start_index_map;
+ attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &start_index_map};
optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<tensorflow::int64>> window_bounds;
- attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &window_bounds};
+ optional<std::vector<tensorflow::int64>> slice_sizes;
+ attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1269,14 +1259,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*offset_dims=*/*offset_dims,
+ /*collapsed_slice_dims=*/*collapsed_slice_dims,
+ /*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
- shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
- dim_numbers, *window_bounds));
+ shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+ dim_numbers, *slice_sizes));
break;
}
case HloOpcode::kScatter: {
@@ -1522,14 +1512,14 @@ bool HloParser::ParseDomain(DomainData* domain) {
return false;
}
if (*kind == ShardingMetadata::KindName()) {
- auto entry_sharding_ptr = MakeUnique<HloSharding>(
+ auto entry_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*entry_sharding).ValueOrDie());
- auto exit_sharding_ptr = MakeUnique<HloSharding>(
+ auto exit_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*exit_sharding).ValueOrDie());
domain->entry_metadata =
- MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
domain->exit_metadata =
- MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
} else {
return TokenError(StrCat("unsupported domain kind: ", *kind));
}
@@ -1938,7 +1928,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = MakeUnique<Literal>(shape);
+ *literal = absl::make_unique<Literal>(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 3f3a51215e..5f0f75c480 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 5990a3d478..0d7919346b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -380,7 +380,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1
}
)"
@@ -393,7 +393,7 @@ R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1)
- ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
+ ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
}
)"
@@ -406,7 +406,7 @@ R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
- ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
+ ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
}
)"
@@ -752,10 +752,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
"gather",
R"(HloModule StringifyGather
-ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1030,8 +1030,8 @@ R"(HloModule gather
ENTRY Gather {
input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1370,7 +1370,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
}
)";
diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h
index 28194deb0e..791b1a97b0 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_fix.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h
@@ -45,7 +45,7 @@ class HloPassFix : public Pass {
++iteration_count;
if (iteration_count == limit) {
LOG(ERROR)
- << "Unexpectedly number of iterations in HLO passes ("
+ << "Unexpectedly high number of iterations in HLO passes ("
<< iteration_count
<< ")\nIf compilation hangs here, please file a bug with XLA.";
}
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a42d7e59fe..3bb1342aa3 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index b2725e2918..8f3ae9c621 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -233,7 +233,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
int64 device = device_assignment(i, 0);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device));
- streams.push_back(MakeUnique<se::Stream>(executor));
+ streams.push_back(absl::make_unique<se::Stream>(executor));
streams.back()->Init();
service_run_options.emplace_back(GetServiceRunOptionsForDevice(
device, streams.back().get(), &device_assignment));
@@ -260,7 +260,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
num_threads += options.num_replicas;
}
if (num_threads > 0) {
- pool = MakeUnique<tensorflow::thread::ThreadPool>(
+ pool = absl::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "infeed_outfeed",
/*num_threads=*/num_threads);
}
@@ -291,7 +291,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = MakeUnique<Literal>();
+ auto literal = absl::make_unique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, options.outfeed_shape, literal.get()));
if (options.outfeed_values != nullptr) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 879fb3bbab..0cba9ebbcb 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -453,7 +453,7 @@ tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
}
size_t HloSharding::Hash() const {
- if (!tuple_) {
+ if (tuple_) {
size_t h = 0;
for (const auto& element : tuple_elements_) {
h = tensorflow::Hash64Combine(h, element.Hash());
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index a2c1d39d0d..4e19557f82 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -121,9 +122,9 @@ std::unique_ptr<HloSharding> CloneShardingForDomain(
const HloSharding& sharding) {
auto single_sharding = sharding.ExtractSingleSharding();
if (!single_sharding) {
- return MakeUnique<HloSharding>(sharding);
+ return absl::make_unique<HloSharding>(sharding);
}
- return MakeUnique<HloSharding>(*single_sharding);
+ return absl::make_unique<HloSharding>(*single_sharding);
}
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
@@ -318,9 +319,9 @@ std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
: "None");
std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_operand_sharding));
+ absl::make_unique<ShardingMetadata>(std::move(real_operand_sharding));
std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_instruction_sharding));
+ absl::make_unique<ShardingMetadata>(std::move(real_instruction_sharding));
return HloInstruction::CreateDomain(operand->shape(), operand,
std::move(operand_side_metadata),
std::move(user_side_metadata));
@@ -357,9 +358,9 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
std::unique_ptr<HloSharding> sharding;
if (sharding_ != nullptr) {
- sharding = MakeUnique<HloSharding>(*sharding_);
+ sharding = absl::make_unique<HloSharding>(*sharding_);
}
- return MakeUnique<ShardingMetadata>(std::move(sharding));
+ return absl::make_unique<ShardingMetadata>(std::move(sharding));
}
bool ShardingMetadata::Matches(const DomainMetadata& other) const {
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 7fd99fc930..14703aaf64 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index e7674f3ddd..ac1a663633 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -84,7 +84,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers()));
+ convolution->window(), convolution->convolution_dimension_numbers(),
+ convolution->feature_group_count()));
return CheckShape(convolution, expected);
}
@@ -156,11 +157,7 @@ Status CheckOperandAndParameter(const HloInstruction* instruction,
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
- // Infeed has an optional single token operand.
- // TODO(b/80000000): Update when token is not optional.
- if (infeed->operand_count() == 1) {
- TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
- }
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
// The output of infeed is a tuple containing the data value and a token.
return CheckShape(infeed,
@@ -170,11 +167,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
- // Outfeed has an optional token operand (operand 1).
- // TODO(b/80000000): Update when token is not optional.
- if (outfeed->operand_count() == 2) {
- TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
- }
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
@@ -579,7 +572,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather,
ShapeInference::InferGatherShape(
gather->operand(0)->shape(), gather->operand(1)->shape(),
- gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+ gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index c942fab08e..9e54b54b26 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
@@ -128,11 +129,11 @@ class HloVerifier : public HloPassInterface {
// Uses standard shape inference.
explicit HloVerifier()
: shape_verifier_factory_(
- [] { return MakeUnique<ShapeVerifier>(false); }) {}
+ [] { return absl::make_unique<ShapeVerifier>(false); }) {}
explicit HloVerifier(bool allow_mixed_precision)
: shape_verifier_factory_([allow_mixed_precision] {
- return MakeUnique<ShapeVerifier>(allow_mixed_precision);
+ return absl::make_unique<ShapeVerifier>(allow_mixed_precision);
}) {}
// Uses custom shape verification.
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 3531b7223f..39dff567d4 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
@@ -153,7 +154,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(),
+ instr->gather_slice_sizes(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kReshape) {
@@ -251,24 +252,23 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
}
- CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
+ CHECK_EQ(dim_numbers.start_index_map_size(), 1);
- // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should
- // it become relevant.
+ // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
+ // should it become relevant.
- if (dim_numbers.elided_window_dims_size() != 1 ||
- dim_numbers.elided_window_dims(0) !=
- dim_numbers.gather_dims_to_operand_dims(0)) {
+ if (dim_numbers.collapsed_slice_dims_size() != 1 ||
+ dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
VLOG(3) << "ComputeArrayForGather: gather operations must elide "
- "gather_dims_to_operand_dims[0] and "
- "gather_dims_to_operand_dims[0] only";
+ "start_index_map[0] and "
+ "start_index_map[0] only";
return nullptr;
}
@@ -277,27 +277,27 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
// arrays from an array of size [7,4,6]. We check that condition down below:
for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
- if (i != dim_numbers.elided_window_dims(0) &&
- source->shape().dimensions(i) != window_bounds[i]) {
- VLOG(3) << "ComputeArrayForGather: window_bounds[" << i
+ if (i != dim_numbers.collapsed_slice_dims(0) &&
+ source->shape().dimensions(i) != slice_sizes[i]) {
+ VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
<< "] != source->shape().dimensions(" << i << ") -- "
- << source->shape().dimensions(i) << " vs. " << window_bounds[i]
- << " with dim_numbers.elided_window_dims(0) = "
- << dim_numbers.elided_window_dims(0);
+ << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
+ << " with dim_numbers.collapsed_slice_dims(0) = "
+ << dim_numbers.collapsed_slice_dims(0);
return nullptr;
}
}
- int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
+ int64 source_dim = dim_numbers.start_index_map(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
output_dims.push_back(i);
}
}
if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
- if (c_linear_search(indexed->output_dims(), source_dim)) {
+ if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
shape);
}
@@ -315,7 +315,7 @@ namespace {
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
- DCHECK(c_all_of(values, [](int64 value) { return value > 0; }));
+ DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));
int64 current_product = 1;
int64 i;
@@ -389,26 +389,26 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
result_subarray_size *= result_shape[result_dim];
}
- c_reverse(result);
+ absl::c_reverse(result);
if (VLOG_IS_ON(3)) {
std::vector<string> result_strings;
- c_transform(result, std::back_inserter(result_strings),
- [](ReshapePassthroughDimPair value) {
- return tensorflow::strings::StrCat(value.result_dim, "->",
- value.operand_dim);
- });
+ absl::c_transform(result, std::back_inserter(result_strings),
+ [](ReshapePassthroughDimPair value) {
+ return tensorflow::strings::StrCat(
+ value.result_dim, "->", value.operand_dim);
+ });
VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
<< Join(result_shape, ",") << "] passthrough indices are ["
<< Join(result_strings, ",") << "] (legend: `result`->`operand`)";
}
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.result_dim < rhs.result_dim;
}));
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.operand_dim < rhs.operand_dim;
}));
@@ -420,20 +420,20 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
- return c_any_of(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == dim;
- });
+ return absl::c_any_of(passthrough_dims,
+ [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == dim;
+ });
}
// Maps `operand_dim` which must be an passthrough operand dimension to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
- auto it = c_find_if(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == operand_dim;
- });
+ auto it = absl::c_find_if(
+ passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == operand_dim;
+ });
CHECK(it != passthrough_dims.end());
return it->result_dim;
}
@@ -454,8 +454,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
Shape StripDegenerateDimensions(const Shape& shape) {
DimensionVector new_dims;
- c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
- [](int64 dim) { return dim != 1; });
+ absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
+ [](int64 dim) { return dim != 1; });
return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}; // namespace
@@ -553,8 +553,8 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
}();
DimensionVector new_result_shape_dims;
- c_copy(operand->shape().dimensions(),
- std::back_inserter(new_result_shape_dims));
+ absl::c_copy(operand->shape().dimensions(),
+ std::back_inserter(new_result_shape_dims));
for (int64 degenerate_dim : degenerate_dims) {
InsertAt(&new_result_shape_dims, degenerate_dim, 1);
}
@@ -695,8 +695,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
operand_dim);
};
- if (!c_all_of(scalar_indexed->output_dims(),
- is_reshape_passthrough_operand_dim)) {
+ if (!absl::c_all_of(scalar_indexed->output_dims(),
+ is_reshape_passthrough_operand_dim)) {
VLOG(3) << "Not all output dims are passthrough dims "
<< ToString(scalar_indexed);
return nullptr;
@@ -735,11 +735,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
// operand = s32[3,5,2] constant({...})
// indices = s32[7] parameter(0)
// gather = s32[3,2,7] gather(operand, indices),
- // output_window_dims={0,1},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0,1},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3,1,2}
+ // slice_sizes={3,1,2}
// reshape = s32[6,7] reshape(gather)
//
// In this case the gather maps to:
@@ -764,8 +764,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
&new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
- CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL,
- std::multiplies<int64>()),
+ CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
+ std::multiplies<int64>()),
ShapeUtil::ElementsIn(scalar_indexed_source_shape));
CHECK(IsReshapePassthroughOperandDim(
@@ -781,9 +781,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
};
std::vector<int64> output_dims_for_new_scalar_indexed_node;
- c_transform(scalar_indexed->output_dims(),
- std::back_inserter(output_dims_for_new_scalar_indexed_node),
- map_passthrough_operand_dim_to_result_dim);
+ absl::c_transform(scalar_indexed->output_dims(),
+ std::back_inserter(output_dims_for_new_scalar_indexed_node),
+ map_passthrough_operand_dim_to_result_dim);
TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
TakeOwnership(scalar_indexed->literal().Reshape(
@@ -874,11 +874,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
auto is_broadcasted_dim = [&](int64 output_dim) {
- return c_find(broadcast_dims, output_dim) == broadcast_dims.end();
+ return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
};
// All of the output dims must be "broadcasted" dims for the other operand.
- if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) {
+ if (!absl::c_all_of(scalar_indexed_const->output_dims(),
+ is_broadcasted_dim)) {
return nullptr;
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index e923dc39f7..675eb31d26 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -265,7 +265,7 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 5f4b42799b..97052edf7d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -82,11 +82,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -102,11 +102,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5] parameter(0)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -122,11 +122,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5,2] parameter(0)
ROOT gather = s32[5] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
@@ -141,11 +141,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0,2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3,1}
+ slice_sizes={1,3,1}
}
)";
@@ -160,11 +160,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2,3] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
}
)";
@@ -179,11 +179,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,2}
+ slice_sizes={1,2}
}
)";
@@ -199,17 +199,17 @@ ENTRY main {
indices_a = s32[5] parameter(0)
indices_b = s32[2] parameter(1)
gather_a = s32[5,3] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -228,17 +228,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[2] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -256,17 +256,17 @@ ENTRY main {
indices_a = s32[2] parameter(1)
indices_b = s32[5,7] parameter(2)
gather_a = s32[2,6] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
}
)";
@@ -284,17 +284,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[4,8] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=2,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -312,11 +312,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5] parameter(0)
gather = s32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2] reshape(gather)
}
)";
@@ -333,11 +333,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,7] parameter(0)
gather = s32[5,4,7] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";
@@ -358,11 +358,11 @@ ENTRY main {
{{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[5,7] parameter(0)
gather = s32[5,2,6,7] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,2,6}
+ slice_sizes={1,2,6}
ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";
@@ -381,11 +381,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -408,14 +408,14 @@ ENTRY main {
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
i.0 = s64[1,3]{1,0} parameter(0)
- g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
- elided_window_dims={0}, gather_dims_to_operand_dims={0},
- index_vector_dim=2, window_bounds={1,3}
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
+ collapsed_slice_dims={0}, start_index_map={0},
+ index_vector_dim=2, slice_sizes={1,3}
i.1 = s64[1] parameter(1)
- g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
- elided_window_dims={1}, gather_dims_to_operand_dims={1},
- index_vector_dim=1, window_bounds={1,1,3}
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
+ collapsed_slice_dims={1}, start_index_map={1},
+ index_vector_dim=1, slice_sizes={1,1,3}
ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
@@ -441,11 +441,11 @@ ENTRY main {
operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -469,11 +469,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[1] parameter(0)
gather = s32[1,1,6] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1,2},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={1,1,6}
+ slice_sizes={1,1,6}
ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";
@@ -500,11 +500,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1,5] parameter(0)
gather = s32[1,5,6] gather(operand, indices),
- output_window_dims={2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";
@@ -530,11 +530,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";
@@ -562,11 +562,11 @@ ENTRY main {
{{1,2},{3,4},{5,6},{7,8},{9,10}}})
indices = s32[7] parameter(0)
gather = s32[3,2,7] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0,1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1,2}
+ slice_sizes={3,1,2}
ROOT reshape = s32[6,7] reshape(gather)
}
)";
@@ -594,11 +594,11 @@ ENTRY main {
{{1},{2},{3},{4}}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6,1] gather(operand, indices),
- output_window_dims={1,3},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,3},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4,1}
+ slice_sizes={1,4,1}
ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";
@@ -623,11 +623,11 @@ ENTRY main {
operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
indices = s32[5] parameter(0)
gather = f32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT tanh = f32[5,4] tanh(gather)
}
)";
@@ -650,11 +650,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -678,11 +678,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";
@@ -706,11 +706,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";
@@ -733,11 +733,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -760,11 +760,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -808,11 +808,11 @@ ENTRY main {
dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_lhs = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -835,11 +835,11 @@ ENTRY main {
dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
indices = s32[5] parameter(0)
dot_lhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
@@ -863,11 +863,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -892,11 +892,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[5,3] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
@@ -921,11 +921,11 @@ ENTRY main {
dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
indices = s32[4] parameter(0)
dot_rhs = s32[2,3,4] gather(gather_operand, indices),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
lhs_contracting_dims={2}, rhs_contracting_dims={1},
lhs_batch_dims={0}, rhs_batch_dims={0}
@@ -952,11 +952,11 @@ ENTRY main {
dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
indices = s32[2] parameter(0)
dot_lhs = s32[3,2] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 32937b33b3..5695bc2420 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index f33942d679..2fd2214806 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -497,7 +498,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput(
bool InstructionFusion::MultiOutputFusionCreatesCycle(
HloInstruction* producer, HloInstruction* consumer) {
- return c_any_of(
+ return absl::c_any_of(
consumer->operands(), [&](const HloInstruction* consumer_operand) {
// The fusion algorithm traverses the HLO graph in reverse post order.
// Thus `cosumers` is visited before its operands (including
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 8652599dc6..581f8d2e92 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -12,12 +12,11 @@ cc_library(
srcs = ["interpreter_transfer_manager.cc"],
hdrs = ["interpreter_transfer_manager.h"],
deps = [
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:generic_transfer_manager",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service/interpreter:platform_id",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -32,8 +31,6 @@ cc_library(
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
@@ -54,6 +51,7 @@ cc_library(
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains compiler registration
)
@@ -79,7 +77,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
@@ -91,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 9f8f4bda87..bb69cb9c47 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -69,8 +69,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
- xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module),
- xla::MakeUnique<HloEvaluator>());
+ absl::make_unique<InterpreterExecutable>(
+ std::move(hlo_module), absl::make_unique<HloEvaluator>());
return std::move(executable);
}
@@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
se::interpreter::kXlaInterpreterPlatformId, []() {
- return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
+ return absl::make_unique<xla::interpreter::InterpreterCompiler>();
});
xla::ComputationPlacer::RegisterComputationPlacer(
se::interpreter::kXlaInterpreterPlatformId,
- []() { return xla::MakeUnique<xla::ComputationPlacer>(); });
+ []() { return absl::make_unique<xla::ComputationPlacer>(); });
return true;
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 8d40c08d55..2259dc1083 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
index d27cd7502f..7955ee5cf3 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager()
static std::unique_ptr<xla::TransferManager>
CreateInterpreterTransferManager() {
- return xla::MakeUnique<xla::InterpreterTransferManager>();
+ return absl::make_unique<xla::InterpreterTransferManager>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index 42c2c28997..e57a9b3672 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/initialize.h"
@@ -70,8 +71,8 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
XlaInterpreterPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
- auto executor = MakeUnique<StreamExecutor>(
- this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
+ auto executor = absl::make_unique<StreamExecutor>(
+ this, absl::make_unique<XlaInterpreterExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 805fdb2d5b..c75bffc63d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -26,9 +26,9 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -137,7 +137,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
}
auto& buffer_set =
buffer_sets_cache_
- .emplace(instruction, MakeUnique<PointsToSet::BufferSet>())
+ .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
.first->second;
const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
points_to_set.ForEachElement(
@@ -1008,7 +1008,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
//
// TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
// from assigning the same layout to input and output.
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
if (instruction->opcode() == HloOpcode::kReshape) {
@@ -1031,13 +1031,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
*operand_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(operand_shape);
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
- return MakeUnique<Layout>(operand_shape.layout());
+ return absl::make_unique<Layout>(operand_shape.layout());
}
if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
*operand_shape.mutable_layout() = output_layout;
if (ShapeUtil::ReshapeIsBitcast(operand_shape,
output_shape_with_layout)) {
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
}
auto aligned_operand_shape =
@@ -1046,7 +1046,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
auto operand_layout = aligned_operand_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
@@ -1062,7 +1062,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
return nullptr;
@@ -1080,7 +1080,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
// Assign users the same layout as the operand.
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
if (user->opcode() == HloOpcode::kReshape) {
@@ -1103,13 +1103,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
*output_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(output_shape);
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
- return MakeUnique<Layout>(output_shape.layout());
+ return absl::make_unique<Layout>(output_shape.layout());
}
if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
*output_shape.mutable_layout() = operand_layout;
if (ShapeUtil::ReshapeIsBitcast(output_shape,
operand_shape_with_layout)) {
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
auto aligned_user_shape =
@@ -1118,7 +1118,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
auto user_layout = aligned_user_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
}
@@ -1134,7 +1134,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
}
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
return nullptr;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index cdd3daf73b..ce2d6678a5 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -88,6 +88,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
"@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index 28ca793e3e..cbfd2e7012 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <vector>
+#include "absl/algorithm/container.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -81,7 +82,7 @@ class IrArray {
}
}
CHECK_NE(index_type_, nullptr);
- CHECK(c_all_of(multidim, [&](llvm::Value* v) {
+ CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) {
return index_type_ == v->getType();
}));
}
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 5e02096ee5..597a788c5d 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index d631fb5ee4..eaa09591b7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
const ShapeIndex& index) {
CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
logical_buffers_.emplace_back(
- MakeUnique<LogicalBuffer>(instruction, index, next_buffer_id_));
+ absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
output_buffers_[std::make_pair(instruction, index)] =
logical_buffers_.back().get();
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 0019cd7254..6aa639a954 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -104,17 +104,17 @@ class MultiOutputFusion : public HloPassInterface {
// InstructionFusion instead.
virtual bool DoProducerConsumerMultiOutputFusion();
- private:
- // Update the internal data structures after instr1 and instr2 are fused into
- // one fusion instruction.
- void Update(HloInstruction* instr1, HloInstruction* instr2);
-
// Optimization fuel is a compiler debugging technique that makes an
// optimization pass stop what it is doing after having made N changes to the
// program, where N is the fuel. By varying N, this can be used to find the
// first single change that makes a test fail.
int64 fuel_;
+ private:
+ // Update the internal data structures after instr1 and instr2 are fused into
+ // one fusion instruction.
+ void Update(HloInstruction* instr1, HloInstruction* instr2);
+
// Computation for the pass.
HloComputation* computation_;
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index ca86c5d13e..4df746fca9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -38,6 +38,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
+
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -374,7 +376,7 @@ StatusOr<bool> TryReshapeMoveOnCandidates(
removed = false;
for (auto operand : nontrivial_operands) {
- if (c_any_of(operand->users(), [&](HloInstruction* user) {
+ if (absl::c_any_of(operand->users(), [&](HloInstruction* user) {
return !reshape_candidates->count(user);
})) {
for (auto* user : operand->users()) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index ccb9fb3e3a..7534a3f7e3 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 45ca731153..338f0c09e9 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/scatter_expander.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -92,7 +93,7 @@ static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
permutation.reserve(updates_rank);
for (int64 i = 0; i < updates_rank; ++i) {
- bool is_scatter_dim = !c_binary_search(update_window_dims, i);
+ bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
if (is_scatter_dim) {
permutation.push_back(i);
}
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 1dbf540d13..18d1b7732b 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -20,10 +20,10 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -245,7 +245,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options) {
- auto config = MakeUnique<HloModuleConfig>(program_shape);
+ auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
@@ -326,7 +326,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
if (directory_path.empty() && execution_directory_path.empty()) {
continue;
}
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
if (!directory_path.empty()) {
string filename =
@@ -409,7 +409,8 @@ Service::ExecuteParallelAndRegisterResult(
streams.push_back(std::move(stream));
if (replica == 0 && profile != nullptr) {
- timers.push_back(MakeUnique<se::Timer>(streams.back()->parent()));
+ timers.push_back(
+ absl::make_unique<se::Timer>(streams.back()->parent()));
streams.back()
->InitTimer(timers.back().get())
.ThenStartTimer(timers.back().get());
@@ -800,7 +801,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
module_proto.name().c_str());
// Dump computation proto state if flag is set.
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
const string& directory_path =
module_config->debug_options().xla_dump_computations_to();
const string& execution_directory_path =
@@ -954,7 +955,7 @@ namespace {
// shape and DeviceMemoryBase values of the clone are identical to the original.
std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
const ShapedBuffer& shaped_buffer, int device_ordinal) {
- auto clone = MakeUnique<ShapedBuffer>(
+ auto clone = absl::make_unique<ShapedBuffer>(
shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(),
shaped_buffer.platform(), device_ordinal);
clone->buffers() = shaped_buffer.buffers();
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a4ea2b28f4..ec6aa6df55 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -1530,7 +1531,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@@ -1640,12 +1641,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 kernel_output_features =
rhs.dimensions(dnums.kernel_output_feature_dimension());
- if (input_features != kernel_input_features) {
+ if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %lld) to match RHS "
- "input feature dimension (value %lld); got <conv>(%s, %s)\n"
+ "input feature dimension * feature_group_count (value %lld); "
+ "got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features,
+ input_features, kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
}
@@ -2491,201 +2493,198 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
static Status ValidateGatherDimensionNumbers(
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.output_window_dims()) !=
- dim_numbers.output_window_dims().end()) {
+ if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
+ dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ Join(dim_numbers.offset_dims(), ", ").c_str());
}
- const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size() - 1;
+ output_offset_dim_count + start_indices_shape.size() - 1;
- for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
- int64 window_index = dim_numbers.output_window_dims(i);
- if (window_index < 0 || window_index >= output_shape_rank) {
+ for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
+ int64 offset_dim = dim_numbers.offset_dims(i);
+ if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Window index %d in gather op is out of bounds; got %lld, but should "
+ "Offset dimension %d in gather op is out of bounds; got %lld, but "
+ "should "
"have been in [0,%lld).",
- i, window_index, output_shape_rank);
+ i, offset_dim, output_shape_rank);
}
}
- if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape[dim_numbers.index_vector_dim()]) {
+ if (dim_numbers.start_index_map_size() !=
+ start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "Gather op has %d elements in gather_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of gather_indices is "
+ "Gather op has %d elements in start_index_map and the "
+ "bound of dimension index_vector_dim=%lld of start_indices is "
"%lld. These two numbers must be equal.",
- dim_numbers.gather_dims_to_operand_dims_size(),
- dim_numbers.index_vector_dim(),
- gather_indices_shape[dim_numbers.index_vector_dim()]);
+ dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
+ start_indices_shape[dim_numbers.index_vector_dim()]);
}
- for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
- int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
- if (gather_dim_to_input_dim < 0 ||
- gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
+ int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
+ if (operand_dim_for_start_index_i < 0 ||
+ operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
- input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ "Invalid start_index_map; domain is [0, %d), got: %d->%lld.",
+ input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
- std::vector<int64> sorted_gather_dims_to_operand_dims(
- dim_numbers.gather_dims_to_operand_dims().begin(),
- dim_numbers.gather_dims_to_operand_dims().end());
+ std::vector<int64> sorted_start_index_map(
+ dim_numbers.start_index_map().begin(),
+ dim_numbers.start_index_map().end());
- c_sort(sorted_gather_dims_to_operand_dims);
+ absl::c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
- sorted_gather_dims_to_operand_dims.end()) {
+ if (absl::c_adjacent_find(sorted_start_index_map) !=
+ sorted_start_index_map.end()) {
return InvalidArgument(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ Join(dim_numbers.start_index_map(), ", ").c_str());
}
- for (int64 elided_dim : dim_numbers.elided_window_dims()) {
- if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
+ if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
"%d), got: %lld.",
- input_shape.dimensions_size(), elided_dim);
+ input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
- "elided_window_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ "collapsed_slice_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
- dim_numbers.elided_window_dims().end()) {
+ if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
- "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
return Status::OK();
}
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
- ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(start_indices_shape, "gather indices operand of gather op"));
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape).c_str());
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
// index_vector_dim is rank(P). The bounds of this expanded shape is
- // stored in expanded_gather_indices_shape.
+ // stored in expanded_start_indices_shape.
- if (gather_indices_shape.dimensions_size() <
+ if (start_indices_shape.dimensions_size() <
gather_dim_numbers.index_vector_dim() ||
gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather index leaf dimension must be within [0, rank(gather_indices) + "
- "1). rank(gather_indices) is %d and gather index leaf dimension is "
+ "Gather index leaf dimension must be within [0, rank(start_indices) + "
+ "1). rank(start_indices) is %d and gather index leaf dimension is "
"%lld.",
- gather_indices_shape.dimensions_size(),
+ start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
- std::vector<int64> expanded_gather_indices_shape;
- expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() ==
+ std::vector<int64> expanded_start_indices_shape;
+ expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
+ absl::c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
+ if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
- expanded_gather_indices_shape.push_back(1);
+ expanded_start_indices_shape.push_back(1);
}
TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
- input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+ input_shape, expanded_start_indices_shape, gather_dim_numbers));
- if (window_bounds.size() != input_shape.dimensions_size()) {
+ if (slice_sizes.size() != input_shape.dimensions_size()) {
return InvalidArgument(
- "Gather op must have one window bound for every input dimension; got: "
- "len(window_bounds)=%lu, input_shape.rank=%d.",
- window_bounds.size(), input_shape.dimensions_size());
+ "Gather op must have one slice size for every input dimension; got: "
+ "len(slice_sizes)=%lu, input_shape.rank=%d.",
+ slice_sizes.size(), input_shape.dimensions_size());
}
- if (window_bounds.size() !=
- gather_dim_numbers.output_window_dims_size() +
- gather_dim_numbers.elided_window_dims_size()) {
+ if (slice_sizes.size() !=
+ gather_dim_numbers.offset_dims_size() +
+ gather_dim_numbers.collapsed_slice_dims_size()) {
return InvalidArgument(
- "All components of the window index in a gather op must either be a "
- "output window index or explicitly elided; got len(window_bounds)=%lu, "
- "output_window_bounds=%s, elided_window_bounds=%s.",
- window_bounds.size(),
- Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
- Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ "All components of the offset index in a gather op must either be a "
+ "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
+ "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
+ Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
}
- for (int i = 0; i < window_bounds.size(); i++) {
- int64 window_bound = window_bounds[i];
- int64 corresponding_input_bound = input_shape.dimensions(i);
- if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ for (int i = 0; i < slice_sizes.size(); i++) {
+ int64 slice_size = slice_sizes[i];
+ int64 corresponding_input_size = input_shape.dimensions(i);
+ if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
- "Window bound at index %d in gather op is out of range, must be "
- "within "
- "[0, %lld), got %lld.",
- i, corresponding_input_bound + 1, window_bound);
+ "Slice size at index %d in gather op is out of range, must be "
+ "within [0, %lld), got %lld.",
+ i, corresponding_input_size + 1, slice_size);
}
}
- for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
- if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
- "Gather op can only elide window indices with bound 1, but bound is "
+ "Gather op can only collapse slice dims with bound 1, but bound is "
"%lld for index %lld at position %d.",
- window_bounds[gather_dim_numbers.elided_window_dims(i)],
- gather_dim_numbers.elided_window_dims(i), i);
+ slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
+ gather_dim_numbers.collapsed_slice_dims(i), i);
}
}
- int64 result_rank = gather_dim_numbers.output_window_dims_size() +
- (expanded_gather_indices_shape.size() - 1);
- int64 window_dims_seen = 0;
+ int64 result_rank = gather_dim_numbers.offset_dims_size() +
+ (expanded_start_indices_shape.size() - 1);
+ int64 offset_dims_seen = 0;
int64 gather_dims_seen = 0;
std::vector<int64> output_dim_bounds;
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
bool is_window_index =
- c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.elided_window_dims(),
- window_dims_seen)) {
- window_dims_seen++;
+ while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
}
- current_bound = window_bounds[window_dims_seen++];
+ current_bound = slice_sizes[offset_dims_seen++];
} else {
if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
gather_dims_seen++;
}
- current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ current_bound = expanded_start_indices_shape[gather_dims_seen++];
}
output_dim_bounds.push_back(current_bound);
@@ -2701,12 +2700,12 @@ Status ValidateScatterDimensionNumbers(
tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.update_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
return InvalidArgument(
"update_window_dims in scatter op must be sorted; got: %s.",
Join(dim_numbers.update_window_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.update_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
dim_numbers.update_window_dims().end()) {
return InvalidArgument(
"update_window_dims in scatter op must not repeat; got: %s.",
@@ -2723,12 +2722,12 @@ Status ValidateScatterDimensionNumbers(
}
// Validate inserted_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.inserted_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
return InvalidArgument(
"inserted_window_dims in scatter op must be sorted; got: %s.",
Join(dim_numbers.inserted_window_dims(), ", ").c_str());
}
- if (c_adjacent_find(dim_numbers.inserted_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
dim_numbers.inserted_window_dims().end()) {
return InvalidArgument(
"inserted_window_dims in scatter op must not repeat; got: %s.",
@@ -2768,8 +2767,8 @@ Status ValidateScatterDimensionNumbers(
std::vector<int64> sorted_scatter_dims_to_operand_dims(
dim_numbers.scatter_dims_to_operand_dims().begin(),
dim_numbers.scatter_dims_to_operand_dims().end());
- c_sort(sorted_scatter_dims_to_operand_dims);
- if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
+ absl::c_sort(sorted_scatter_dims_to_operand_dims);
+ if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
sorted_scatter_dims_to_operand_dims.end()) {
return InvalidArgument(
"Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
@@ -2836,32 +2835,32 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_numbers));
int64 inserted_dims_seen = 0;
- std::vector<int64> max_update_window_bounds;
+ std::vector<int64> max_update_slice_sizes;
for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
++inserted_dims_seen;
} else {
- max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ max_update_slice_sizes.push_back(operand_shape.dimensions(i));
}
}
for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
if (updates_shape.dimensions(update_window_dim) >
- max_update_window_bounds[i]) {
+ max_update_slice_sizes[i]) {
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
"%lld, updates bound is %lld, operand bound is %lld.",
update_window_dim, updates_shape.dimensions(update_window_dim),
- max_update_window_bounds[i]);
+ max_update_slice_sizes[i]);
}
}
int64 scatter_dims_seen = 0;
for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
bool is_update_window_dim =
- c_binary_search(scatter_dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c185b0a1bd..4974ac9916 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -112,7 +112,8 @@ class ShapeInference {
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(
@@ -275,9 +276,9 @@ class ShapeInference {
// with the given input shape, gather indices shape and gather dimension
// numbers.
static StatusOr<Shape> InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index a73fa181cd..4ed8fc6b86 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1654,11 +1654,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ /*slice_sizes=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1669,11 +1669,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{1},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1684,11 +1684,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{4},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1700,11 +1700,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
@@ -1717,11 +1717,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1735,11 +1735,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1749,16 +1749,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(
- f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3, 4},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{0, 1, 2, 3, 4},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
@@ -1772,11 +1771,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{0, 1, 2, 3},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/0),
- /*window_bounds=*/{1, 30, 29, 28, 27}));
+ /*slice_sizes=*/{1, 30, 29, 28, 27}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
@@ -1787,11 +1786,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for input"))
@@ -1802,11 +1801,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for gather indices"))
@@ -1817,11 +1816,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather indices parameter must be an integral tensor"))
@@ -1833,11 +1832,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 8, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 8, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1850,11 +1849,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1867,14 +1866,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 99, 100, 101},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 99, 100, 101},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 2 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 2 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1883,14 +1882,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 9},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 9},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 4 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 4 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1899,16 +1898,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{4},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{4},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr("All components of the window index in a gather op must either "
- "be a output window index or explicitly elided"))
+ HasSubstr("All components of the offset index in a gather op must either "
+ "be a offset dimension or explicitly collapsed"))
<< statusor.status();
}
@@ -1917,14 +1916,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 19},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
"range is [0, 5), got: 19"))
<< statusor.status();
}
@@ -1934,16 +1933,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 3},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Repeated dimensions not allowed in "
+ "collapsed_slice_dims in gather op"))
<< statusor.status();
}
@@ -1952,17 +1950,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
- "the bound of dimension index_vector_dim=4 of "
- "gather_indices is 5. These two numbers must be equal."))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op has 4 elements in start_index_map and "
+ "the bound of dimension index_vector_dim=4 of "
+ "start_indices is 5. These two numbers must be equal."))
<< statusor.status();
}
@@ -1971,16 +1968,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 7},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
- "[0, 5), got: 4->7"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
<< statusor.status();
}
@@ -1989,16 +1984,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ HasSubstr("Repeated dimensions are not allowed in start_index_map"))
<< statusor.status();
}
@@ -2007,14 +2001,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{2, 1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{2, 1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 1, 28, 27, 26});
+ /*slice_sizes=*/{1, 1, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("elided_window_dims in gather op must be sorted"))
+ HasSubstr("collapsed_slice_dims in gather op must be sorted"))
<< statusor.status();
}
@@ -2023,15 +2017,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{2},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{2},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 1, 300, 26});
+ /*slice_sizes=*/{30, 29, 1, 300, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window bound at index 3 in gather op is out of range, "
- "must be within [0, 48), got 300"))
+ HasSubstr("Slice size at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300."))
<< statusor.status();
}
@@ -2040,16 +2034,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26});
+ /*slice_sizes=*/{30, 29, 28, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Gather op must have one window bound for every input dimension"))
+ HasSubstr("Gather op must have one slice size for every input dimension"))
<< statusor.status();
}
@@ -2058,15 +2051,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26, 20});
+ /*slice_sizes=*/{30, 29, 28, 26, 20});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Gather op can only elide window indices with bound 1, "
- "but bound is 29 for index 1 at position 0"))
+ HasSubstr("Gather op can only collapse slice dims with bound 1, "
+ "but bound is 29 for index 1 at position 0."))
<< statusor.status();
}
@@ -2074,16 +2067,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/32),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather index leaf dimension must be within [0, "
- "rank(gather_indices) + 1)"))
+ "rank(start_indices) + 1)"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 7d7dcac10b..70714ffff0 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
index 0fc2436679..d69e6362e9 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
xla::StreamExecutorMemoryAllocator allocator(platform, executors);
const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {});
const int kDeviceOrdinal = 0;
- auto scoped_buffer = tensorflow::MakeUnique<xla::ScopedShapedBuffer>(
+ auto scoped_buffer = absl::make_unique<xla::ScopedShapedBuffer>(
shape, shape, &allocator, kDeviceOrdinal);
std::unique_ptr<xla::ShapedBuffer> buffer = std::move(scoped_buffer);
buffer = nullptr;
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index c0582c6a2d..5d1cd1c442 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/stream_pool.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
if (!stream) {
// Create a new stream.
- stream = MakeUnique<se::Stream>(executor);
+ stream = absl::make_unique<se::Stream>(executor);
stream->Init();
VLOG(1) << stream->DebugStreamPointers()
<< " StreamPool created new stream";
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 32d368a904..e0f995fd0d 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -61,7 +62,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return absl::make_unique<Literal>(std::move(literal));
}
Status TransferManager::TransferLiteralFromDevice(
@@ -120,7 +121,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return absl::make_unique<Literal>(std::move(literal));
}
Status TransferManager::TransferArrayToDevice(
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 475a2e5c14..f77690a462 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -152,6 +152,26 @@ class TransferManager {
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal);
+ // The given ShapedBuffer holds a handle to allocated memory, but it is not
+ // in the general case legal to immediately copy or access that allocated
+ // memory because queued operations on the device may alias that memory.
+ // Memory ordering is enforced by the Stream's happens-before relationship
+ // which allows eager deallocation and reallocation of buffers host-side even
+ // if the device hasn't finished with them.
+ //
+ // In certain cases, it can be known that a ShapedBuffer does not have any
+ // conflicting accesses on the device and thus is eligible to be accessed at
+ // any time from the host.
+ //
+ // This function returns true if device_buffer can be accessed immediately
+ // without waiting for the Stream's previously enqueued items. This only
+ // returns true if all subbuffers in device_buffer can be accessed
+ // immediately.
+ virtual bool CanShapedBufferBeAccessedNow(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
+ return false;
+ }
+
/////
// The TransferManager class also serves as a point to register objects for
// the various platforms.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0447807a41..0c2f2112af 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -441,7 +442,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
PerInstruction* pi = PerInst(instruction);
CHECK(pi->points_to_set == nullptr)
<< "instruction should not have been present in the map.";
- auto set = MakeUnique<PointsToSet>(&instruction->shape());
+ auto set = absl::make_unique<PointsToSet>(&instruction->shape());
pi->points_to_set = std::move(set);
// Return *set using the iterator returned by emplace.
return *pi->points_to_set;
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 62af45128a..aab1180662 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance(
std::vector<HloInstruction*> users;
users.reserve(old_instr->user_count());
- c_copy(old_instr->users(), std::back_inserter(users));
+ absl::c_copy(old_instr->users(), std::back_inserter(users));
for (auto* user : users) {
for (int64 i = 0, e = user->operand_count(); i < e; i++) {
@@ -108,10 +109,10 @@ StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) {
//
// This will let us sink the constant into the outer while first and then
// into the inner while in a single run of this pass.
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
index 266039d2ff..0e7667de83 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
@@ -206,7 +206,8 @@ body {
p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1
- outfeed = token[] outfeed(p_body.0)
+ token = token[] after-all()
+ outfeed = token[] outfeed(p_body.0, token)
ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 09ddcffb22..cb132d4f16 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy(
};
InlinedVector<HloInstruction*, 4> new_operands;
- c_transform(old_instruction->operands(), std::back_inserter(new_operands),
- get_new_operand);
+ absl::c_transform(old_instruction->operands(),
+ std::back_inserter(new_operands), get_new_operand);
HloInstruction* new_instruction =
parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
@@ -197,7 +198,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
op->opcode() == HloOpcode::kConstant;
};
- if (!c_all_of(instruction->operands(), is_invariant)) {
+ if (!absl::c_all_of(instruction->operands(), is_invariant)) {
continue;
}
@@ -257,10 +258,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 1ef17b9d7d..52d9c3e5ae 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -206,7 +207,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
- c_copy(init_values, std::back_inserter(init_values_with_indvar));
+ absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
HloInstruction::CreateTuple(init_values_with_indvar));
}
@@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
std::vector<Shape> loop_state_shape_components;
loop_state_shape_components.reserve(init_values.size() + 1);
loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
- c_transform(init_values, std::back_inserter(loop_state_shape_components),
- [](HloInstruction* instr) { return instr->shape(); });
+ absl::c_transform(init_values,
+ std::back_inserter(loop_state_shape_components),
+ [](HloInstruction* instr) { return instr->shape(); });
return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index 2ccb919acf..5e69419333 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
@@ -206,7 +207,7 @@ ENTRY main {
auto is_while = [](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kWhile;
};
- EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+ EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1);
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index c74dd648ad..186c42ed13 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -21,8 +21,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index c4c958be4a..c8ff55e784 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_tree.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
- *shape_tree.mutable_element({2}) = MakeUnique<int>(42);
+ *shape_tree.mutable_element({2}) = absl::make_unique<int>(42);
EXPECT_EQ(*shape_tree.element({2}), 42);
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 0f8cffd466..4d5c9efe9b 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -113,7 +113,6 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:backend",
@@ -127,6 +126,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -144,6 +145,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -187,7 +189,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
@@ -201,6 +202,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -274,6 +276,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -385,6 +388,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -813,6 +817,7 @@ CONVOLUTION_TEST_DEPS = [
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -824,7 +829,7 @@ xla_test(
timeout = "long",
srcs = ["convolution_test.cc"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS,
+ deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"],
)
xla_test(
@@ -834,7 +839,7 @@ xla_test(
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
backends = ["gpu"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS,
+ deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"],
)
xla_test(
@@ -885,6 +890,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1067,6 +1073,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1148,6 +1155,7 @@ xla_test_library(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1155,6 +1163,7 @@ xla_test(
name = "reduce_window_test",
timeout = "long",
srcs = [],
+ shard_count = 20,
tags = [
"enable_for_xla_interpreter",
"optonly",
@@ -1210,6 +1219,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1285,6 +1295,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1350,6 +1361,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1411,6 +1423,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1540,17 +1553,16 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1751,6 +1763,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1772,6 +1785,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
)
@@ -1823,6 +1837,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
@@ -1849,6 +1864,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2086,6 +2102,7 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index c7b94b5bba..74d4d2eb10 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 59d917054b..2cab3264a7 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include <string>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -546,7 +546,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
int rows, int cols, float offset) {
- auto array = MakeUnique<Array2D<float>>(rows, cols);
+ auto array = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 row = 0; row < rows; ++row) {
for (int64 col = 0; col < cols; ++col) {
(*array)(row, col) = col + (row * 1000.0f) + offset;
@@ -561,7 +561,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
int cols_padded) {
CHECK_GE(rows_padded, rows);
CHECK_GE(cols_padded, cols);
- auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0);
+ auto array = absl::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
for (int64 row = 0; row < rows; ++row) {
for (int64 col = 0; col < cols; ++col) {
(*array)(row, col) = col + (row * 1000.0f);
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index b04a3b105c..24d0325929 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -613,7 +613,7 @@ template <typename NativeT>
std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
- auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
+ auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 1adc68cc48..7a203d6873 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -447,11 +448,11 @@ std::vector<float> GetInterestingF16ConversionTestCases() {
XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
std::vector<float> test_cases = GetInterestingF16ConversionTestCases();
std::vector<half> input;
- c_transform(test_cases, std::back_inserter(input),
- [](float f) { return Eigen::half(f); });
+ absl::c_transform(test_cases, std::back_inserter(input),
+ [](float f) { return Eigen::half(f); });
std::vector<float> expected_output;
- c_transform(input, std::back_inserter(expected_output),
- [](Eigen::half h) { return static_cast<float>(h); });
+ absl::c_transform(input, std::back_inserter(expected_output),
+ [](Eigen::half h) { return static_cast<float>(h); });
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
@@ -470,8 +471,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
std::vector<float> input = GetInterestingF16ConversionTestCases();
std::vector<half> expected_output;
- c_transform(input, std::back_inserter(expected_output),
- [](float f) { return Eigen::half(f); });
+ absl::c_transform(input, std::back_inserter(expected_output),
+ [](float f) { return Eigen::half(f); });
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 7b6bbc4f57..38b6da4fa9 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <array>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -88,9 +88,9 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) {
XLA_TEST_F(ConvolutionDimensionNumbersTest,
TwoConvsWithDifferentDimensionNumbers) {
- auto input_array = MakeUnique<Array4D<float>>(2, 3, 5, 5);
+ auto input_array = absl::make_unique<Array4D<float>>(2, 3, 5, 5);
input_array->FillWithMultiples(0.1);
- auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
+ auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
client_
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 5ed8122e00..40658c3b77 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -26,11 +27,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -70,16 +71,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
const int kKernelSizeY = 2;
const int kOutputActivationSizeZ = 256;
const int kMiniBatchSize = 4;
- auto alhs =
- MakeUnique<Array4D<T>>(kMiniBatchSize, kInputActivationSizeZ,
- kInputActivationSizeY, kInputActivationSizeX);
+ auto alhs = absl::make_unique<Array4D<T>>(
+ kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
+ kInputActivationSizeX);
alhs->FillWithMultiples(static_cast<T>(1.0f));
ASSERT_EQ(3, alhs->width());
ASSERT_EQ(3, alhs->height());
- auto arhs =
- MakeUnique<Array4D<T>>(kOutputActivationSizeZ, kInputActivationSizeZ,
- kKernelSizeY, kKernelSizeX);
+ auto arhs = absl::make_unique<Array4D<T>>(kOutputActivationSizeZ,
+ kInputActivationSizeZ,
+ kKernelSizeY, kKernelSizeX);
Array2D<T> rhs_raster({
{1.0f, 0.0f}, // row 0
{0.0f, 0.0f}, // row 1
@@ -465,7 +466,7 @@ void iota_int_init_value(std::vector<T>& values, int init_value) {
}
template <typename T>
-class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
+class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
public:
void RunTest() {
XlaBuilder builder(TestName());
@@ -520,8 +521,139 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
}
};
-TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes);
-TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); }
+TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); }
+
+template <typename T>
+class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
+ public:
+ void RunTest() {
+ XlaBuilder builder(TestName());
+ std::vector<int64> input_dims = {1, 3, 3, 5};
+ std::vector<int64> filter_dims = {3, 3, 1, 15};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
+ {
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
+ /*feature_group_count=*/5);
+ }
+
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
+ iota_int_init_value(input_elems, 1);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
+ auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ iota_int_init_value(filter_elems, 1);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
+ auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+
+ auto expected_r1 = LiteralUtil::CreateR1<T>(
+ {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
+ static_cast<T>(17172), static_cast<T>(17370), static_cast<T>(17568),
+ static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
+ static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
+ static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
+ auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+
+ auto input_literal =
+ client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+
+ ComputeAndCompareLiteral(&builder, *expected_r4,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+};
+
+TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) {
+ this->RunTest();
+}
+
+template <typename T>
+class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
+ public:
+ void RunTest() {
+ XlaBuilder builder(TestName());
+ std::vector<int64> input_dims = {1, 2, 2, 6};
+ std::vector<int64> filter_dims = {2, 2, 2, 12};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
+ {
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
+ /*feature_group_count=*/3);
+ }
+
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
+ iota_int_init_value(input_elems, 1);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
+ auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ iota_int_init_value(filter_elems, 1);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
+ auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+
+ auto expected_r1 = LiteralUtil::CreateR1<T>(
+ {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
+ static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
+ static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
+ static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
+ auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+
+ auto input_literal =
+ client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+
+ ComputeAndCompareLiteral(&builder, *expected_r4,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+};
+
+TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) {
+ this->RunTest();
+}
// Test fixture to run convolution tests with and without convolution
// canonicalization enabled.
@@ -765,5 +897,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
std::move(*LiteralUtil::CreateFromArray(filter_data))});
}
+class ConvolutionHloTest : public HloTestBase {};
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %arg0 = f64[3,56,56,16] parameter(0)
+ %arg1 = f64[3,3,3,64] parameter(1)
+ ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %arg0 = f64[2,5,8,1] parameter(0)
+ %arg1 = f64[2,5,8,2] parameter(1)
+ ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %output = f64[4,5,16,16] parameter(0)
+ %kernel = f64[5,3,7,7] parameter(1)
+ %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
+ ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 5ef273e5a2..50a9ebc1e9 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 13c777835e..6f7fc0e6e5 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 792be0d3fc..341124170a 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -22,13 +22,13 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b77bece85a..f866ed6519 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -30,8 +30,8 @@ using tensorflow::gtl::nullopt;
class GatherOperationTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
- Literal* gather_indices) {
- RunTest(hlo_text, {operand, gather_indices});
+ Literal* start_indices) {
+ RunTest(hlo_text, {operand, start_indices});
}
void RunTest(const string& hlo_text,
@@ -52,18 +52,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -74,18 +73,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -96,18 +94,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -118,18 +116,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -140,18 +138,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,1,1,2] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -162,20 +160,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -186,20 +184,20 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -210,18 +208,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -232,18 +229,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -254,17 +251,16 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -278,19 +274,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -304,19 +300,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -330,19 +326,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -356,19 +352,19 @@ ENTRY main {
operand = u32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = u32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = u32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -379,17 +375,17 @@ ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
- output_window_dims={0,1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0},
+ offset_dims={0,1,2},
+ collapsed_slice_dims={},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1,3,2}
+ slice_sizes={1,3,2}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -400,16 +396,16 @@ ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[] gather(operand, index),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1}
+ slice_sizes={1}
}
)";
std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -420,17 +416,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[0] parameter(1)
ROOT gather = s32[0,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -441,11 +437,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -453,9 +449,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -466,11 +461,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -478,9 +473,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -491,11 +486,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -503,9 +498,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -516,11 +511,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -530,9 +525,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest,
@@ -544,11 +539,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -558,9 +553,9 @@ ENTRY main {
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -571,11 +566,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -583,9 +578,8 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -596,11 +590,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
@@ -608,9 +602,9 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ std::unique_ptr<Literal> start_indices =
LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, operand.get(), start_indices.get());
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -622,11 +616,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// ROOT gather = s32[2,3] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0},
- // gather_dims_to_operand_dims={0},
+ // offset_dims={1},
+ // collapsed_slice_dims={0},
+ // start_index_map={0},
// index_vector_dim=1,
- // window_bounds={1, 3}
+ // slice_sizes={1, 3}
// }
XlaBuilder builder("gather_basic");
@@ -637,9 +631,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
auto operand = Parameter(&builder, 0, operand_shape, "operand");
auto indices = Parameter(&builder, 1, indices_shape, "indices");
GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(1);
- dim_numbers.add_elided_window_dims(0);
- dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.add_offset_dims(1);
+ dim_numbers.add_collapsed_slice_dims(0);
+ dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
Gather(operand, indices, dim_numbers, {1, 3});
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index f05d1a8b9d..2167d4240e 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -20,12 +20,15 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
@@ -92,15 +95,29 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
bool allow_mixed_precision_in_hlo_verifier)
: test_runner_(test_platform), reference_runner_(reference_platform) {
hlo_verifier_ =
- MakeUnique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+ absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
}
-/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
- return MakeUnique<HloModule>(name, GetModuleConfigForTest());
+ return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
+}
+
+/* static */
+StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
+ HloModule* module) {
+ const string module_str_before_run = module->ToProto().ShortDebugString();
+ const auto status_or = hlo_pass->Run(module);
+ if (status_or.status().ok()) {
+ const string module_str_after_run = module->ToProto().ShortDebugString();
+ if (!status_or.ValueOrDie()) {
+ // Check that the proto remains same.
+ EXPECT_EQ(module_str_after_run, module_str_before_run);
+ }
+ }
+ return status_or;
}
-/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
+DebugOptions HloTestBase::GetDebugOptionsForTest() {
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
debug_options.add_xla_disable_hlo_passes("constant_folding");
@@ -199,7 +216,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
[](const std::unique_ptr<Literal>& literal) { return literal.get(); });
@@ -213,7 +230,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
const auto& fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
[](const std::unique_ptr<Literal>& literal) { return literal.get(); });
@@ -248,7 +265,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
MakeFakeArguments(module_or_status.ValueOrDie().get())
.ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
[](const std::unique_ptr<Literal>& literal) { return literal.get(); });
return test_runner_
@@ -303,8 +320,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
HloComputation* HloTestBase::FindComputation(HloModule* module,
tensorflow::StringPiece name) {
auto computations = module->computations();
- auto it = c_find_if(computations,
- [&](HloComputation* c) { return c->name() == name; });
+ auto it = absl::c_find_if(
+ computations, [&](HloComputation* c) { return c->name() == name; });
if (it == computations.end()) {
return nullptr;
}
@@ -315,8 +332,8 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module,
tensorflow::StringPiece name) {
for (const HloComputation* c : module->computations()) {
auto instructions = c->instructions();
- auto it = c_find_if(instructions,
- [&](HloInstruction* i) { return i->name() == name; });
+ auto it = absl::c_find_if(
+ instructions, [&](HloInstruction* i) { return i->name() == name; });
if (it != instructions.end()) {
return *it;
}
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 4232eeceb1..5c7304b4de 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -72,8 +72,13 @@ class HloTestBase : public ::testing::Test {
// options from command-line flags. If you want a fresh HloModule object and
// then add HloComputations to it, it's recommended to use this method in your
// tests.
- static std::unique_ptr<HloModule> CreateNewModule(
- const string& name = TestName());
+ std::unique_ptr<HloModule> CreateNewModule(const string& name = TestName());
+
+ // Runs the hlo_pass with the provided module and returns the result. This
+ // function also verifies that the module remains unchanged when hlo_pass
+ // returns false as the StatusOr value.
+ static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
+ HloModule* module);
protected:
// This uses the interpreter backend as the reference backend and
@@ -93,10 +98,13 @@ class HloTestBase : public ::testing::Test {
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
- static DebugOptions GetDebugOptionsForTest();
+ //
+ // This function is virtual so tests can specify an alternative set of debug
+ // options (e.g. disabling additional passes).
+ virtual DebugOptions GetDebugOptionsForTest();
// Gets an HloModuleConfig with options appropriate for tests.
- static HloModuleConfig GetModuleConfigForTest() {
+ HloModuleConfig GetModuleConfigForTest() {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
return config;
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index ad1f5b9eed..a509ee3207 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -25,7 +26,7 @@ limitations under the License.
namespace xla {
HloVerifiedTestBase::HloVerifiedTestBase()
- : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
+ : shape_verifier_(absl::make_unique<ShapeVerifier>()) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index e719da54d4..8d65869557 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
@@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test {
static std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return MakeUnique<HloModule>(TestName(), config);
+ return absl::make_unique<HloModule>(TestName(), config);
}
};
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index 6fc1115097..0487d31409 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -51,8 +51,9 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const string& pattern,
bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
- TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status());
+ Status status = CompileToExecutable(std::move(hlo_module)).status();
ResetIrHook();
+ TF_ASSERT_OK(status);
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
TF_ASSERT_OK(filecheck_result.status());
@@ -73,9 +74,10 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
const string& pattern, bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
- TF_ASSERT_OK(
- CompileToAotCompilationResult(std::move(hlo_module), options).status());
+ Status status =
+ CompileToAotCompilationResult(std::move(hlo_module), options).status();
ResetIrHook();
+ TF_ASSERT_OK(status);
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
ASSERT_TRUE(filecheck_result.ok());
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index eaddf756db..948b60061e 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -18,11 +18,11 @@ limitations under the License.
#include <vector>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index da8c42d465..b6035a21a6 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -133,7 +133,7 @@ class TestLinspaceMaxParametric
float from = -128.0, to = 256.0;
std::unique_ptr<Array2D<T>> alhs =
MakeLinspaceArray2D<T>(from, to, rows, cols);
- auto arhs = MakeUnique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
+ auto arhs = absl::make_unique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
XlaBuilder builder(
tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index eb06b115da..cadf1c5523 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <new>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index ca21b0b2ba..cbeddffacf 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -16,12 +16,12 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -140,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
XlaBuilder b(TestName());
- auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
+ auto input = absl::make_unique<Array4D<float>>(1, 1, 3, 2);
Array2D<float> input_xy({
{1.0f, 2.0f}, // row 0
{3.0f, 4.0f}, // row 1
@@ -151,7 +151,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
expected->Fill(1.5);
(*expected)(1, 0, 0, 0) = 1.0f;
(*expected)(1, 0, 0, 1) = 2.0f;
@@ -171,7 +171,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected->Fill(pad_value);
(*expected)(1, 0, 0, 0) = 1.0f;
(*expected)(1, 2, 0, 0) = 2.0f;
@@ -269,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
XLA_TEST_F(PadTest, Pad4DU8Array) {
XlaBuilder b(TestName());
- auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
+ auto input = absl::make_unique<Array4D<uint8>>(1, 1, 3, 2);
Array2D<uint8> input_xy({
{1, 2}, // row 0
{3, 4}, // row 1
@@ -280,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) {
Pad(AddParam(*input, &b), ConstantR0<uint8>(&b, 35),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<uint8>>(2, 3, 3, 2);
expected->Fill(35);
(*expected)(1, 0, 0, 0) = 1;
(*expected)(1, 0, 0, 1) = 2;
@@ -301,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
Pad(input, ConstantR0<bool>(&b, false), r4_padding_on_dim0_dim1_);
// For the same reason, use Select to convert boolean values to int32.
- auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
- auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto zeros = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
+ auto ones = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
zeros->Fill(0);
ones->Fill(1);
Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b));
- auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
expected->Fill(0);
(*expected)(1, 0, 0, 0) = 1;
(*expected)(1, 0, 0, 1) = 1;
@@ -321,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
XLA_TEST_P(PadTestFloat, Large2DPad) {
XlaBuilder b(TestName());
- auto ones = MakeUnique<Array2D<float>>(4, 4);
+ auto ones = absl::make_unique<Array2D<float>>(4, 4);
ones->Fill(1.0f);
auto input = AddParam(*ones, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -342,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
constexpr int64 in_rows = 35;
constexpr int64 in_cols = 35;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(0.0f);
auto input = AddParam(*operand, &b);
@@ -368,7 +368,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
constexpr int64 low_padding = 0;
int64 high_padding[2] = {5, 7};
constexpr int64 interior_padding = 0;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -395,7 +395,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {-3, 4};
constexpr int64 interior_padding = 0;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -423,7 +423,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
int64 low_padding[2] = {4, -1};
int64 high_padding[2] = {-2, -4};
int64 interior_padding[2] = {1, 2};
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -446,7 +446,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
// Regression test for b/31827337.
XLA_TEST_P(PadTestFloat, ReducePad) {
XlaBuilder b(TestName());
- auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
+ auto ones = absl::make_unique<Array4D<float>>(2, 2, 2, 2);
ones->Fill(1.0);
auto input = AddParam(*ones, &b);
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 92c93f08b2..09acadb2c2 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -357,7 +358,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- auto arg_literal = MakeUnique<Literal>(shape);
+ auto arg_literal = absl::make_unique<Literal>(shape);
arg_literal->PopulateWithValue(1.0f);
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@@ -368,7 +369,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
Shape result_shape =
ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
- auto expected = MakeUnique<Literal>(result_shape);
+ auto expected = absl::make_unique<Literal>(result_shape);
expected->PopulateWithValue(27.0f);
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
}
@@ -1261,6 +1262,12 @@ struct R1ReduceWindowTestData {
/*pad_low=*/{5},
/*pad_high=*/{0},
/*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+ /*strides=*/{1},
+ /*pad_low=*/{4095},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kMax},
};
string R1ReduceWindowTestDataToString(
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index faeec657b6..2f1d97b25d 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include <cmath>
+
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace xla {
@@ -26,89 +29,101 @@ namespace {
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
- // Create uniform numbers between 1 and 1.125 to avoid creating denormal
- // numbers.
- std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f);
- const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
- TF_CHECK_OK(literal->Populate<FloatT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
- // Generate a random uniform number from -0.0625 and 0.0625 and bias it
- // with a position dependent number with mean 0.037109375. These number
- // should allow for long chains of accumulation without being too close
- // to zero or too large to accumulate all numbers accurately. Only do
- // this for large literals where the number of elements is much greater
- // than 47 otherwise only negative values are produced.
- //
- // The value is positionally biased using a product of the indices. Add
- // one to each index value to avoid collapsing to zero if any of the
- // indices are zero.
- int64 index_product = 1;
- for (int64 i : indices) {
- index_product *= (1 + i);
- }
- const int64 negative_bias = should_index_bias ? 47 : 0;
- FloatT index_bias =
- static_cast<FloatT>(index_product % 113 - negative_bias) /
- static_cast<FloatT>(256.0f);
- return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias;
- }));
+ if (no_duplicates) {
+ // Duplicates may be generated if the number of elements in the literal
+ // exceeds the number of positive values supported by the type.
+ FloatT next_value = std::numeric_limits<FloatT>::min();
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = next_value;
+ next_value =
+ std::nextafter(next_value, std::numeric_limits<FloatT>::max());
+ }
+ std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
+ *engine);
+ } else {
+ std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = static_cast<FloatT>(generator(*engine));
+ }
+ }
}
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
+ PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine,
+ no_duplicates);
}
template <>
void PopulateWithRandomFloatingPointData<half>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for half types. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (half& value : literal->data<half>()) {
+ value = static_cast<half>(generator(*engine));
+ }
}
-// The standard library does not have a case for bfloat16, unsurprisingly, so we
-// handle that one specially.
template <>
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for bfloat types. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- CHECK_EQ(literal->shape().element_type(), BF16);
- std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
- TF_CHECK_OK(literal->Populate<bfloat16>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return static_cast<bfloat16>(generator(*engine));
- }));
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (bfloat16& value : literal->data<bfloat16>()) {
+ value = static_cast<bfloat16>(generator(*engine));
+ }
}
template <typename IntT>
-void PopulateWithRandomIntegralData(Literal* literal,
- std::minstd_rand0* engine) {
+void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
- std::uniform_int_distribution<IntT> generator(
- std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
- TF_CHECK_OK(literal->Populate<IntT>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(*engine);
- }));
+ if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
+ std::numeric_limits<IntT>::max()) {
+ std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
+ std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
+ *engine);
+ } else {
+ std::uniform_int_distribution<IntT> generator(
+ std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+ for (IntT& value : literal->data<IntT>()) {
+ value = generator(*engine);
+ }
+ }
}
// Similar to MakeFakeLiteral but takes a random number generator engine to
-// enable reusing the engine across randomly generated literals.
+// enable reusing the engine across randomly generated literals. 'no_duplicates'
+// indicates that there should be no duplicate values in each generated
+// array. This is uniqueness is best-effort only. Some types (half and bfloat16)
+// are not supported and uniqueness cannot be guaranteed if the number of
+// elements exceeds the number of different values supported by the type.
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine) {
+ const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
std::vector<std::unique_ptr<Literal>> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteralInternal(element_shape, engine));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> element,
+ MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
return LiteralUtil::MakeTupleOwned(std::move(elements));
@@ -116,43 +131,55 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = MakeUnique<Literal>(shape);
+ auto literal = absl::make_unique<Literal>(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+ no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+ no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+ no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+ no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int8>(literal.get(), engine,
+ no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
+ no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int16>(literal.get(), engine,
+ no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
+ no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int32>(literal.get(), engine,
+ no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
+ no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int64>(literal.get(), engine,
+ no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
+ no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
@@ -250,6 +277,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
auto converted_uses = FindConstrainedUses(dataflow, *instruction);
constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
converted_uses.end());
+ } else if (opcode == HloOpcode::kSort &&
+ instruction->operand_count() == 2 && op_num == 0) {
+ // Operand 0 of sort is the array of keys used for key/value
+ // (two-operand) kSort instructions.
+ constrained_uses.push_back(instruction);
}
}
}
@@ -264,6 +296,7 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
+ bool no_duplicates = false;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
@@ -302,16 +335,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
constant_type = GetInitValue(*use->scatter());
break;
+ case HloOpcode::kSort:
+ no_duplicates = true;
+ break;
+
default:
return Unimplemented(
"Constrained operand generation not implemented for %s.",
use->ToString().c_str());
}
}
- if (!index_space.empty() && needs_constant) {
- return Unimplemented(
- "Conflicting operand generation constraints. Dynamically indexes a "
- "shape and is the init value of a reduction.");
+ int constraint_count = 0;
+ constraint_count += no_duplicates ? 1 : 0;
+ constraint_count += !index_space.empty() ? 1 : 0;
+ constraint_count += needs_constant ? 1 : 0;
+ if (constraint_count > 1) {
+ return Unimplemented("Conflicting operand generation constraints.");
}
if (!index_space.empty()) {
return MakeRandomIndex(index_space, engine);
@@ -324,10 +363,11 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine,
+ /*no_duplicates=*/false);
}
} else {
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
}
}
@@ -344,19 +384,26 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
bool pseudo_random) {
- auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
- return MakeFakeLiteralInternal(shape, engine.get());
+ auto engine =
+ pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
+ return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
}
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
HloModule* const module, bool pseudo_random) {
+ auto engine =
+ pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
+ return MakeFakeArguments(module, engine.get());
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ HloModule* const module, std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
- .ValueOrDie();
+ arguments[i] =
+ MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
}
return std::move(arguments);
}
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index e59f215a9a..1aca1d8ef7 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <memory>
#include <random>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -63,8 +63,17 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
// Generates a vector of arguments containing fake data. The number, shape and
// layout of the arguments is appropriate for given HLO module.
//
-// Will handle special cases such as making sure that indices used for dynamic
-// slices are bounded, reduces that call adds use 0 as an init value, etc.
+// A best-effort attempt is made to generate the data in a way which produce
+// stable computation results across platforms. Specifically:
+//
+// (1) Init values of reductions should be the identity of the reduction
+// computation.
+//
+// (2) Indices of dynamic slices and update slices should be in bounds.
+//
+// (3) Keys of key/value sorts should contain no duplicates.
+//
+// These constraints are best-effort only.
//
// If pseudo_random is true, the generated numbers will be generated
// deterministically in a pseudo random way unless the values are constrated to
@@ -78,6 +87,12 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
HloModule* const module, bool pseudo_random = true);
+// Overload which accepts a random number generator. This enables generation of
+// different random values with sequential calls to MakeFakeArguments by reusing
+// the same generator.
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+ HloModule* const module, std::minstd_rand0* engine);
+
// Check that a given module satisfies various constraints before trying to
// execute it.
Status VerifyHloModule(HloModule* const module,
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 64d9e2031e..322c8ef090 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -127,5 +128,51 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
+XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
+ // Inputs which are sort keys in key/value sorts should have no duplicates.
+ auto module = ParseHloString(R"(
+HloModule sort.148.1589
+
+ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
+ %parameter.0 = f32[1048576]{0} parameter(0)
+ %parameter.1 = s32[1048576]{0} parameter(1)
+ ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+}
+)")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 2);
+ const Literal& key_arg = *args[0];
+
+ tensorflow::gtl::FlatSet<uint32> key_set;
+ for (const float& value : key_arg.data<float>()) {
+ EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
+ }
+}
+
+XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
+ // Inputs which are sort keys in key/value sorts should have no duplicates.
+ auto module = ParseHloString(R"(
+HloModule sort.148.1589
+
+ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
+ %parameter.0 = s32[1048576]{0} parameter(0)
+ %parameter.1 = s32[1048576]{0} parameter(1)
+ ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+}
+)")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 2);
+ const Literal& key_arg = *args[0];
+
+ tensorflow::gtl::FlatSet<int32> key_set;
+ for (const int32& value : key_arg.data<int32>()) {
+ EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 97bbf80aff..c101cd2d20 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <initializer_list>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -504,7 +505,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = MakeUnique<Literal>(sum->shape());
+ auto prod = absl::make_unique<Literal>(sum->shape());
ASSERT_TRUE(prod->Populate<complex64>(
[&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
return sum->Get<complex64>(indexes) *
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 11f3efb1f3..e12e095ecd 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -116,7 +117,7 @@ Status ParseOneProfileOutputLine(
", Regexp: ", regexp_pattern);
}
- if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
+ if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
}
@@ -294,7 +295,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
tensorflow::str_util::Split(profile_output, '\n');
auto while_body_profile_start =
- c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
+ absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
return tensorflow::str_util::StartsWith(s,
"Execution profile for body");
});
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 897123d760..7de2c39b38 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -102,7 +102,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
ShapeUtil::HumanString(shape).c_str());
}
- auto result = MakeUnique<Literal>(shape);
+ auto result = absl::make_unique<Literal>(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
result->PopulateWithValue<float>(fill);
std::vector<tensorflow::StringPiece> pieces;
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 5ae099a462..cc07346ee5 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -434,122 +435,15 @@ std::vector<std::pair<int64, int64>> CommonFactors(
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
-template <typename Container, typename Predicate>
-bool c_all_of(const Container& container, Predicate&& predicate) {
- return std::all_of(std::begin(container), std::end(container),
- std::forward<Predicate>(predicate));
-}
-
-template <typename Container, typename Predicate>
-bool c_any_of(const Container& container, Predicate&& predicate) {
- return std::any_of(std::begin(container), std::end(container),
- std::forward<Predicate>(predicate));
-}
-
-template <typename InputContainer, typename OutputIterator,
- typename UnaryOperation>
-OutputIterator c_transform(const InputContainer& input_container,
- OutputIterator output_iterator,
- UnaryOperation&& unary_op) {
- return std::transform(std::begin(input_container), std::end(input_container),
- output_iterator,
- std::forward<UnaryOperation>(unary_op));
-}
-
-template <class InputContainer, class OutputIterator, class UnaryPredicate>
-OutputIterator c_copy_if(const InputContainer& input_container,
- OutputIterator output_iterator,
- UnaryPredicate&& predicate) {
- return std::copy_if(std::begin(input_container), std::end(input_container),
- output_iterator, std::forward<UnaryPredicate>(predicate));
-}
-
-template <class InputContainer, class OutputIterator>
-OutputIterator c_copy(const InputContainer& input_container,
- OutputIterator output_iterator) {
- return std::copy(std::begin(input_container), std::end(input_container),
- output_iterator);
-}
-
-template <class InputContainer>
-void c_sort(InputContainer& input_container) {
- std::sort(std::begin(input_container), std::end(input_container));
-}
-
-template <class InputContainer, class Comparator>
-void c_sort(InputContainer& input_container, Comparator&& comparator) {
- std::sort(std::begin(input_container), std::end(input_container),
- std::forward<Comparator>(comparator));
-}
-
-template <typename Sequence, typename T>
-bool c_binary_search(const Sequence& sequence, T&& value) {
- return std::binary_search(std::begin(sequence), std::end(sequence),
- std::forward<T>(value));
-}
-
-template <typename C>
-bool c_is_sorted(const C& c) {
- return std::is_sorted(std::begin(c), std::end(c));
-}
-
-template <typename C, typename Compare>
-bool c_is_sorted(const C& c, Compare&& comp) {
- return std::is_sorted(std::begin(c), std::end(c),
- std::forward<Compare>(comp));
-}
-
-template <typename C>
-auto c_adjacent_find(C& c) -> decltype(std::begin(c)) {
- return std::adjacent_find(std::begin(c), std::end(c));
-}
-
-template <typename C, typename Pred>
-auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) {
- return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
-}
-
-template <typename C, typename Value>
-auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) {
- return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
-}
-
-template <typename Sequence>
-void c_reverse(Sequence& sequence) {
- std::reverse(std::begin(sequence), std::end(sequence));
-}
-
-template <typename Sequence, typename T, typename BinaryOp>
-typename std::decay<T>::type c_accumulate(const Sequence& sequence, T&& init,
- BinaryOp&& binary_op) {
- return std::accumulate(std::begin(sequence), std::end(sequence),
- std::forward<T>(init),
- std::forward<BinaryOp>(binary_op));
-}
-
-template <typename C, typename Pred>
-typename std::iterator_traits<
- decltype(std::begin(std::declval<C>()))>::difference_type
-c_count_if(const C& c, Pred&& pred) {
- return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
-}
-
-// Determines whether `value` is present in `c`.
-template <typename C, typename T>
-bool c_linear_search(const C& c, T&& value) {
- auto last = std::end(c);
- return std::find(std::begin(c), last, std::forward<T>(value)) != last;
-}
-
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
- auto it = c_find(c, std::forward<Value>(value));
+ auto it = absl::c_find(c, std::forward<Value>(value));
return std::distance(c.begin(), it);
}
template <typename T>
bool ArrayContains(tensorflow::gtl::ArraySlice<T> c, const T& value) {
- return c_find(c, value) != c.end();
+ return absl::c_find(c, value) != c.end();
}
template <typename C, typename Value>
@@ -584,8 +478,8 @@ bool IsInt32(T x) {
template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
- // c_find returns a const_iterator which does not seem to work on gcc 4.8.4,
- // and this breaks the ubuntu/xla_gpu build bot.
+ // absl::c_find returns a const_iterator which does not seem to work on
+ // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
auto it = std::find(container->begin(), container->end(), value);
TF_RET_CHECK(it != container->end());
container->erase(it);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 3b72eb17c6..b53f89d63b 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -195,8 +195,13 @@ message DebugOptions {
bool xla_cpu_enable_fast_math = 99;
bool xla_gpu_enable_fast_math = 100;
- // Extra options to pass to the compilation backend; specific interpretation
- // of these values is left to the backend.
+ // Crashes the program when any kind of verification fails, instead of just
+ // logging the failures. One example is cross checking of convolution results
+ // among different algorithms.
+ bool xla_gpu_crash_on_verification_failures = 101;
+
+ // Extra options to pass to the compilation backend (e.g. LLVM); specific
+ // interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 4c35e93d38..27aa94c2cb 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -424,25 +424,25 @@ message GatherDimensionNumbers {
// "Window indices" is a term for a set of indices that index into the
// interior of a dynamic-slice from the input tensor, the starting indices for
// which were computed from output_gather_dims (see the operation semantic for
- // how this is defined) and the gather_indices tensor.
+ // how this is defined) and the start_indices tensor.
//
// The window indices for a specific output index Out is computed as:
//
// i = 0
// for (k : [0, input_tensor_shape.rank))
// window_indices[k] =
- // if k in elided_window_dims
+ // if k in collapsed_slice_dims
// then 0
- // else Out[output_window_dims[i++]]
- repeated int64 output_window_dims = 1;
- repeated int64 elided_window_dims = 2;
+ // else Out[offset_dims[i++]]
+ repeated int64 offset_dims = 1;
+ repeated int64 collapsed_slice_dims = 2;
- // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
- // transforms the gather index looked up from the gather_indices tensor into
+ // This is interpreted as a map from i to start_index_map[i]. It
+ // transforms the gather index looked up from the start_indices tensor into
// the starting index in the input space.
- repeated int64 gather_dims_to_operand_dims = 3;
+ repeated int64 start_index_map = 3;
- // The dimension in the gather_indices input that contains the starting
+ // The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
}
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 23bb783e22..f7e3c8d8fb 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -64,6 +64,7 @@ py_library(
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
+ "//tensorflow/contrib/lite/python:lite",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/losses:metric_learning_py",
@@ -135,7 +136,6 @@ py_library(
# This is an issue with the tensorrt static library and will be fixed by
# the next tensorrt release, so fix the order here after that.
"//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code
]),
)
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index e18ea8df4d..45a7680160 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -94,8 +94,7 @@ from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.eager.python import tfe as eager
-if os.name != "nt":
- from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
from tensorflow.contrib.recurrent.python import recurrent_api as recurrent
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 68ead2f760..9afe3df585 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -14,8 +14,6 @@
# ==============================================================================
"""Monte Carlo integration and helpers.
-See the @{$python/contrib.bayesflow.monte_carlo} guide.
-
@@expectation
@@expectation_importance_sampler
@@expectation_importance_sampler_logspace
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 68d710d713..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -16,7 +16,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import tempfile
+import numpy as np
+
from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column
@@ -26,6 +29,7 @@ from tensorflow.python.feature_column import feature_column_lib as core_feature_
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
@@ -473,6 +477,63 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
classifier.predict(input_fn=_eval_input_fn)
+ def testWeightedCategoricalColumn(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ feature_columns = [
+ core_feature_column.weighted_categorical_column(
+ categorical_column=core_feature_column.
+ categorical_column_with_vocabulary_list(
+ key="word", vocabulary_list=["the", "cat", "dog"]),
+ weight_feature_key="weight")
+ ]
+
+ labels = np.array([[1], [1], [0], [0.]], dtype=np.float32)
+
+ def _make_input_fn():
+
+ def _input_fn():
+ features_dict = {}
+ # Sparse tensor representing
+ # example 0: "cat","the"
+ # examaple 1: "dog"
+ # example 2: -
+ # example 3: "the"
+ # Weights for the words are 5 - cat, 6- dog and 1 -the.
+ features_dict["word"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=constant_op.constant(
+ ["the", "cat", "dog", "the"], dtype=dtypes.string),
+ dense_shape=[4, 3])
+ features_dict["weight"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=[1., 5., 6., 1.],
+ dense_shape=[4, 3])
+ return features_dict, labels
+
+ return _input_fn
+
+ est = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns)
+
+ input_fn = _make_input_fn()
+ est.train(input_fn=input_fn, steps=100)
+ est.evaluate(input_fn=input_fn, steps=1)
+ est.predict(input_fn=input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 20ff48c360..2f75d8aa99 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -218,6 +218,21 @@ def extract_features(features, feature_columns, use_core_columns):
sparse_int_shapes = []
for key in sorted(features.keys()):
tensor = features[key]
+ # TODO(nponomareva): consider iterating over feature columns instead.
+ if isinstance(tensor, tuple):
+ # Weighted categorical feature.
+ categorical_tensor = tensor[0]
+ weight_tensor = tensor[1]
+
+ shape = categorical_tensor.dense_shape
+ indices = array_ops.concat([
+ array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]),
+ array_ops.expand_dims(
+ math_ops.to_int64(categorical_tensor.values), -1)
+ ], 1)
+ tensor = sparse_tensor.SparseTensor(
+ indices=indices, values=weight_tensor.values, dense_shape=shape)
+
if isinstance(tensor, sparse_tensor.SparseTensor):
if tensor.values.dtype == dtypes.float32:
sparse_float_names.append(key)
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index e92f0bb841..150d734db6 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -34,6 +34,9 @@ Checkpointable data structures:
Checkpoint management:
@@CheckpointManager
+
+Saving and restoring Python state:
+@@NumpyState
"""
from __future__ import absolute_import
@@ -41,6 +44,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
+from tensorflow.contrib.checkpoint.python.python_state import NumpyState
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index 7b200a29bf..ada4168726 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -9,6 +9,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":containers",
+ ":python_state",
":split_dependency",
":visualize",
"//tensorflow/python/training/checkpointable:data_structures",
@@ -41,6 +42,33 @@ py_test(
)
py_library(
+ name = "python_state",
+ srcs = ["python_state.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python/training/checkpointable:base",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "python_state_test",
+ srcs = ["python_state_test.py"],
+ deps = [
+ ":python_state",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:util",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
new file mode 100644
index 0000000000..9b11035b6d
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -0,0 +1,166 @@
+"""Utilities for including Python state in TensorFlow checkpoints."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy
+
+from tensorflow.python.training.checkpointable import base
+
+# pylint: disable=g-import-not-at-top
+try:
+ # In Python 2.x, use the faster string buffering option.
+ from cStringIO import StringIO as BytesIO
+except ImportError:
+ from io import BytesIO
+# pylint: enable=g-import-not-at-top
+
+
+class NumpyState(base.CheckpointableBase):
+ """A checkpointable object whose NumPy array attributes are saved/restored.
+
+ Example usage:
+
+ ```python
+ arrays = tf.contrib.checkpoint.NumpyState()
+ checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
+ arrays.x = numpy.zeros([3, 4])
+ save_path = checkpoint.save("/tmp/ckpt")
+ arrays.x[1, 1] = 4.
+ checkpoint.restore(save_path)
+ assert (arrays.x == numpy.zeros([3, 4])).all()
+
+ second_checkpoint = tf.train.Checkpoint(
+ numpy_arrays=tf.contrib.checkpoint.NumpyState())
+ # Attributes of NumpyState objects are created automatically by restore()
+ second_checkpoint.restore(save_path)
+ assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
+ ```
+
+ Note that `NumpyState` objects re-create the attributes of the previously
+ saved object on `restore()`. This is in contrast to TensorFlow variables, for
+ which a `Variable` object must be created and assigned to an attribute.
+
+ This snippet works both when graph building and when executing eagerly. On
+ save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
+ a placeholder when graph building, or as a string constant when executing
+ eagerly). When restoring they skip the TensorFlow graph entirely, and so no
+ restore ops need be run. This means that restoration always happens eagerly,
+ rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
+ TensorFlow variables when graph building.
+ """
+
+ def _lookup_dependency(self, name):
+ """Create placeholder NumPy arrays for to-be-restored attributes.
+
+ Typically `_lookup_dependency` is used to check by name whether a dependency
+ exists. We cheat slightly by creating a checkpointable object for `name` if
+ we don't already have one, giving us attribute re-creation behavior when
+ loading a checkpoint.
+
+ Args:
+ name: The name of the dependency being checked.
+ Returns:
+ An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
+ dependency (which will generally be restored immediately).
+ """
+ value = super(NumpyState, self)._lookup_dependency(name)
+ if value is None:
+ value = _NumpyWrapper(numpy.array([]))
+ new_reference = base.CheckpointableReference(name=name, ref=value)
+ self._unconditional_checkpoint_dependencies.append(new_reference)
+ self._unconditional_dependency_names[name] = value
+ super(NumpyState, self).__setattr__(name, value)
+ return value
+
+ def __getattribute__(self, name):
+ """Un-wrap `_NumpyWrapper` objects when accessing attributes."""
+ value = super(NumpyState, self).__getattribute__(name)
+ if isinstance(value, _NumpyWrapper):
+ return value.array
+ return value
+
+ def __setattr__(self, name, value):
+ """Automatically wrap NumPy arrays assigned to attributes."""
+ # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
+ # ndarrays checkpointable natively and using standard checkpointable list
+ # tracking.
+ if isinstance(value, numpy.ndarray):
+ try:
+ existing = super(NumpyState, self).__getattribute__(name)
+ existing.array = value
+ return
+ except AttributeError:
+ value = _NumpyWrapper(value)
+ self._track_checkpointable(value, name=name, overwrite=True)
+ elif (name not in ("_setattr_tracking", "_update_uid")
+ and getattr(self, "_setattr_tracking", True)):
+ # Mixing restore()-created attributes with user-added checkpointable
+ # objects is tricky, since we can't use the `_lookup_dependency` trick to
+ # re-create attributes (we might accidentally steal the restoration for
+ # another checkpointable object). For now `NumpyState` objects must be
+ # leaf nodes. Theoretically we could add some extra arguments to
+ # `_lookup_dependency` to figure out whether we should create a NumPy
+ # array for the attribute or not.
+ raise NotImplementedError(
+ ("Assigned %s to the %s property of %s, which is not a NumPy array. "
+ "Currently mixing NumPy arrays and other checkpointable objects is "
+ "not supported. File a feature request if this limitation bothers "
+ "you.")
+ % (value, name, self))
+ super(NumpyState, self).__setattr__(name, value)
+
+
+class _NumpyWrapper(base.CheckpointableBase):
+ """Wraps a NumPy array for storage in an object-based checkpoint."""
+
+ def __init__(self, array):
+ """Specify a NumPy array to wrap.
+
+ Args:
+ array: The NumPy array to save and restore (may be overwritten).
+ """
+ self.array = array
+
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the array."""
+ string_file = BytesIO()
+ try:
+ numpy.save(string_file, self.array, allow_pickle=False)
+ serialized = string_file.getvalue()
+ finally:
+ string_file.close()
+ return serialized
+
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ string_file = BytesIO(string_value)
+ try:
+ self.array = numpy.load(string_file, allow_pickle=False)
+ finally:
+ string_file.close()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "array": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
new file mode 100644
index 0000000000..0439a4755e
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy
+
+from tensorflow.contrib.checkpoint.python import python_state
+from tensorflow.python.client import session
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.training.checkpointable import util
+
+
+class NumpyStateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreNumpyState(self):
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_state.b = numpy.ones([2, 2])
+ save_state.b = numpy.zeros([2, 2])
+ self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ first_save_path = saver.save(prefix)
+ save_state.a[1, 1] = 2.
+ second_save_path = saver.save(prefix)
+
+ load_state = python_state.NumpyState()
+ loader = util.Checkpoint(numpy=load_state)
+ loader.restore(first_save_path).initialize_or_restore()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ load_state.a[0, 0] = 42.
+ self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
+ loader.restore(first_save_path).run_restore_ops()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ loader.restore(second_save_path).run_restore_ops()
+ self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+
+ def testNoGraphPollution(self):
+ graph = ops.Graph()
+ with graph.as_default(), session.Session():
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_path = saver.save(prefix)
+ saver.restore(save_path)
+ graph.finalize()
+ saver.save(prefix)
+ save_state.a = numpy.zeros([2, 2])
+ saver.save(prefix)
+ saver.restore(save_path)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoMixedNumpyStateTF(self):
+ save_state = python_state.NumpyState()
+ save_state.a = numpy.ones([2, 2])
+ with self.assertRaises(NotImplementedError):
+ save_state.v = variables.Variable(1.)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDocstringExample(self):
+ arrays = python_state.NumpyState()
+ checkpoint = util.Checkpoint(numpy_arrays=arrays)
+ arrays.x = numpy.zeros([3, 4])
+ save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ arrays.x[1, 1] = 4.
+ checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
+
+ second_checkpoint = util.Checkpoint(numpy_arrays=python_state.NumpyState())
+ second_checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
index 95e7e744d3..cb45e42734 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import json
+import os
from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops
from tensorflow.python.framework import dtypes
@@ -188,6 +189,8 @@ def configure_colab_session(session):
session: A `tf.Session` session.
"""
# Read from the application default credentials (adc).
- with open('/content/datalab/adc.json') as f:
+ adc_filename = os.environ.get(
+ 'GOOGLE_APPLICATION_CREDENTIALS', '/content/adc.json')
+ with open(adc_filename) as f:
data = json.load(f)
configure_gcs(session, credentials=data)
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake
index 1d638e6402..479609458c 100644
--- a/tensorflow/contrib/cmake/external/nsync.cmake
+++ b/tensorflow/contrib/cmake/external/nsync.cmake
@@ -16,16 +16,16 @@ include (ExternalProject)
set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public)
set(nsync_URL https://github.com/google/nsync)
-set(nsync_TAG 1.20.0)
+set(nsync_TAG 1.20.1)
set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync)
set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install)
if(WIN32)
set(nsync_HEADERS "${nsync_BUILD}/public/*.h")
- set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib)
+ set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync_cpp.lib)
else()
set(nsync_HEADERS "${nsync_BUILD}/public/*.h")
- set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync.a)
+ set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync_cpp.a)
endif()
ExternalProject_Add(nsync
@@ -35,12 +35,12 @@ ExternalProject_Add(nsync
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
BUILD_BYPRODUCTS ${nsync_STATIC_LIBRARIES}
- PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/nsync/CMakeLists.txt ${nsync_BUILD}
INSTALL_DIR ${nsync_INSTALL}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL}
+ -DCMAKE_INSTALL_LIBDIR:STRING=lib
-DNSYNC_LANGUAGE:STRING=c++11)
set(nsync_HEADERS
diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
deleted file mode 100644
index 6f059c7225..0000000000
--- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
+++ /dev/null
@@ -1,325 +0,0 @@
-cmake_minimum_required (VERSION 2.8.12)
-
-# nsync provides portable synchronization primitives, such as mutexes and
-# condition variables.
-project (nsync)
-
-# Set variable NSYNC_LANGUAGE to "c++11" to build with C++11
-# rather than C.
-
-# Some builds need position-independent code.
-set (CMAKE_POSITION_INDEPENDENT_CODE ON)
-
-# -----------------------------------------------------------------
-# Platform dependencies
-
-# Many platforms use these posix related sources; even Win32.
-set (NSYNC_POSIX_SRC
- "platform/posix/src/nsync_panic.c"
- "platform/posix/src/per_thread_waiter.c"
- "platform/posix/src/time_rep.c"
- "platform/posix/src/yield.c"
-)
-
-if (WIN32)
- # Suppress warnings to reduce build log size.
- add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
- add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
- add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
- add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
- add_definitions(/wd8029)
-endif()
-
-# Many of the string matches below use a literal "X" suffix on both sides.
-# This is because some versions of cmake treat (for example) "MSVC" (in quotes)
-# as a reference to the variable MSVC, thus the expression
-# "${CMAKE_C_COMPILER_ID}" STREQUAL "MSVC"
-# is false when ${CMAKE_C_COMPILER_ID} has the value "MSVC"! See
-# https://cmake.org/cmake/help/v3.1/policy/CMP0054.html
-
-# Pick the include directory for the operating system.
-if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11")
- add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11")
- set (NSYNC_OS_CPP_SRC
- "platform/c++11/src/per_thread_waiter.cc"
- "platform/c++11/src/yield.cc"
- "platform/c++11/src/time_rep_timespec.cc"
- "platform/c++11/src/nsync_panic.cc"
- )
- if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/win32")
- add_compile_options ("/TP")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- "platform/win32/src/clock_gettime.c"
- "platform/win32/src/pthread_key_win32.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/win32/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/macos")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE
- # when including certin C++ standard header files, such as <mutex>.
- add_definitions ("-D_DARWIN_C_SOURCE")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- ${NSYNC_OS_CPP_SRC}
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- "platform/posix/src/clock_gettime.c"
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX")
- include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/linux/src/nsync_semaphore_futex.c"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- endif ()
-endif ()
-
-# Pick the include directory for the compiler.
-if ("${CMAKE_C_COMPILER_ID}X" STREQUAL "GNUX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/gcc")
- set (THREADS_HAVE_PTHREAD_ARG ON)
-elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "ClangX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/clang")
- set (THREADS_HAVE_PTHREAD_ARG ON)
-elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "MSVCX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/msvc")
-else ()
- message (WARNING "CMAKE_C_COMPILER_ID (${CMAKE_C_COMPILER_ID}) matched NOTHING")
-endif ()
-
-if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/win32")
- set (NSYNC_OS_SRC
- ${NSYNC_POSIX_SRC}
- "platform/win32/src/clock_gettime.c"
- "platform/win32/src/init_callback_win32.c"
- "platform/win32/src/nanosleep.c"
- "platform/win32/src/nsync_semaphore_win32.c"
- "platform/win32/src/pthread_cond_timedwait_win32.c"
- "platform/win32/src/pthread_key_win32.cc"
- )
- set (NSYNC_TEST_OS_SRC
- "platform/win32/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/macos")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/clock_gettime.c"
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/linux")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/linux/src/nsync_semaphore_futex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- endif ()
-endif ()
-
-if (NSYNC_POSIX)
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- set (NSYNC_OS_SRC
- ${NSYNC_POSIX_SRC}
- ${NSYNC_OS_EXTRA_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
-endif ()
-
-# Pick the include directory for the architecture.
-if (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "amd64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "AMD64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_64")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_32X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i386X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i686X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_32")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv6lX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv7lX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armX"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/arm")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "aarch64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "arm64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/aarch64")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppcX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc32X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc32")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc64")
-endif ()
-
-# Windows uses some include files from the posix directory also.
-if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
-endif ()
-
-# -----------------------------------------------------------------
-
-include_directories ("${PROJECT_SOURCE_DIR}/public")
-include_directories ("${PROJECT_SOURCE_DIR}/internal")
-
-set (NSYNC_SRC
- "internal/common.c"
- "internal/counter.c"
- "internal/cv.c"
- "internal/debug.c"
- "internal/dll.c"
- "internal/mu.c"
- "internal/mu_wait.c"
- "internal/note.c"
- "internal/once.c"
- "internal/sem_wait.c"
- "internal/time_internal.c"
- "internal/wait.c"
- ${NSYNC_OS_SRC}
-)
-add_library (nsync ${NSYNC_SRC})
-
-set (NSYNC_TEST_SRC
- "testing/array.c"
- "testing/atm_log.c"
- "testing/closure.c"
- "testing/smprintf.c"
- "testing/testing.c"
- "testing/time_extra.c"
- ${NSYNC_TEST_OS_SRC}
-)
-add_library (nsync_test ${NSYNC_TEST_SRC})
-
-set (NSYNC_TESTS
- "counter_test"
- "cv_mu_timeout_stress_test"
- "cv_test"
- "cv_wait_example_test"
- "dll_test"
- "mu_starvation_test"
- "mu_test"
- "mu_wait_example_test"
- "mu_wait_test"
- "note_test"
- "once_test"
- "pingpong_test"
- "wait_test"
-)
-
-if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- foreach (s IN ITEMS ${NSYNC_SRC} ${NSYNC_TEST_SRC})
- SET_SOURCE_FILES_PROPERTIES ("${s}" PROPERTIES LANGUAGE CXX)
- endforeach (s)
- foreach (t IN ITEMS ${NSYNC_TESTS})
- SET_SOURCE_FILES_PROPERTIES ("testing/${t}.c" PROPERTIES LANGUAGE CXX)
- endforeach (t)
-endif ()
-
-enable_testing ()
-foreach (t IN ITEMS ${NSYNC_TESTS})
- add_executable (${t} "testing/${t}.c")
-endforeach (t)
-
-find_package (Threads REQUIRED)
-set (THREADS_PREFER_PTHREAD_FLAG ON)
-foreach (t IN ITEMS "nsync" "nsync_test" ${NSYNC_TESTS})
- if (THREADS_HAVE_PTHREAD_ARG)
- target_compile_options (${t} PUBLIC "-pthread")
- endif ()
- if (CMAKE_THREAD_LIBS_INIT)
- target_link_libraries (${t} "${CMAKE_THREAD_LIBS_INIT}")
- endif ()
-endforeach (t)
-
-foreach (t IN ITEMS ${NSYNC_TESTS})
- target_link_libraries (${t} nsync_test nsync)
- add_test (NAME ${t} COMMAND ${t})
-endforeach (t)
-
-install (TARGETS nsync
- LIBRARY DESTINATION lib COMPONENT RuntimeLibraries
- ARCHIVE DESTINATION lib COMPONENT Development)
-
-set (NSYNC_INCLUDES
- "public/nsync.h"
- "public/nsync_atomic.h"
- "public/nsync_counter.h"
- "public/nsync_cpp.h"
- "public/nsync_cv.h"
- "public/nsync_debug.h"
- "public/nsync_mu.h"
- "public/nsync_mu_wait.h"
- "public/nsync_note.h"
- "public/nsync_once.h"
- "public/nsync_time.h"
- "public/nsync_time_internal.h"
- "public/nsync_waiter.h"
-)
-
-foreach (NSYNC_INCLUDE ${NSYNC_INCLUDES})
- install (FILES ${NSYNC_INCLUDE} DESTINATION include COMPONENT Development)
-endforeach ()
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index a5a947f726..07934ef324 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -4,6 +4,8 @@ tensorflow
tensorflow/core
tensorflow/core/example
tensorflow/core/framework
+tensorflow/core/kernels
+tensorflow/core/kernels/boosted_trees
tensorflow/core/lib
tensorflow/core/lib/core
tensorflow/core/profiler
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index 615e62b16f..fe5e34d258 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Linear-chain CRF layer.
-See the @{$python/contrib.crf} guide.
+See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide.
@@crf_binary_score
@@crf_decode
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index dbfff9b4f8..5821d51bca 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -20,7 +20,7 @@ be used in conjunction with the `tf.data.Dataset` API. Note that the
guarantees as `tf.data`, but we will provide deprecation advice in advance of
removing existing functionality.
-See @{$guide/datasets$Importing Data} for an overview.
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@Counter
@@CheckpointInputPipelineHook
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2b75aa2ca5..4df75c1edb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -139,7 +139,6 @@ py_test(
srcs = ["interleave_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
- "manual",
"no_oss",
"no_pip",
"notap",
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 44c3325a3d..7a3215f6cc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -777,6 +777,34 @@ class ParallelInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
+ def testShutdownRace(self):
+ dataset = dataset_ops.Dataset.range(20)
+ map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ map_fn,
+ cycle_length=3,
+ sloppy=False,
+ buffer_output_elements=1,
+ prefetch_input_elements=0))
+ dataset = dataset.batch(32)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ results = []
+ with self.test_session() as sess:
+ for _ in range(2):
+ elements = []
+ sess.run(iterator.initializer)
+ try:
+ while True:
+ elements.extend(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ results.append(elements)
+
+ self.assertAllEqual(results[0], results[1])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index d3628d480d..c16f1d6035 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -29,7 +29,6 @@ py_library(
"//tensorflow/contrib/distribute/python:cross_tower_ops",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
"//tensorflow/contrib/distribute/python:monitor",
- "//tensorflow/contrib/distribute/python:multi_worker_strategy",
"//tensorflow/contrib/distribute/python:one_device_strategy",
"//tensorflow/contrib/distribute/python:parameter_server_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 2c93ce92ce..588a4f2898 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -23,7 +23,6 @@ from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
-from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
@@ -40,7 +39,6 @@ _allowed_symbols = [
'CrossTowerOps',
'DistributionStrategy',
'MirroredStrategy',
- 'MultiWorkerMirroredStrategy',
'Monitor',
'OneDeviceStrategy',
'ParameterServerStrategy',
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 40a1c1707c..59efd17746 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -72,31 +72,21 @@ py_library(
":cross_tower_ops",
":shared_variable_creator",
":values",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:device",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "multi_worker_strategy",
- srcs = ["multi_worker_strategy.py"],
- visibility = ["//tensorflow:internal"],
- deps = [
- ":mirrored_strategy",
- ":values",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
],
)
@@ -114,6 +104,7 @@ py_library(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:multi_worker_util",
],
)
@@ -184,7 +175,6 @@ py_library(
],
deps = [
":mirrored_strategy",
- ":multi_worker_strategy",
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
@@ -219,9 +209,13 @@ py_test(
],
deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":strategy_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index aeec9c44d7..2fbadfe0f5 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -48,7 +48,6 @@ import six
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
-from tensorflow.contrib.distribute.python import multi_worker_strategy
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
@@ -344,31 +343,31 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
multi_worker_strategy_with_cpu = NamedDistribution(
"MultiWorkerCPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
+ lambda: mirrored_lib.MirroredStrategy(
+ cluster_spec={
"worker": [
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
]
},
- num_gpus_per_worker=0), 0)
+ num_gpus=0), 0)
multi_worker_strategy_with_one_gpu = NamedDistribution(
"MultiWorker1GPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
+ lambda: mirrored_lib.MirroredStrategy(
+ cluster_spec={
"worker": [
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
]
},
- num_gpus_per_worker=1), 1)
+ num_gpus=1), 1)
multi_worker_strategy_with_two_gpus = NamedDistribution(
"MultiWorker2GPUs",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
+ lambda: mirrored_lib.MirroredStrategy(
+ cluster_spec={
"worker": [
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
]
},
- num_gpus_per_worker=2), 2)
+ num_gpus=2), 2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 3a7addf221..163559587d 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -53,7 +53,7 @@ def validate_destinations(destinations):
if not isinstance(
destinations,
(value_lib.DistributedValues, resource_variable_ops.ResourceVariable,
- six.string_types, list)):
+ value_lib.AggregatingVariable, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, a device string, a list of device "
"strings or None")
@@ -78,7 +78,8 @@ def _validate_value_destination_pairs(value_destination_pairs):
def get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
- elif isinstance(destinations, resource_variable_ops.ResourceVariable):
+ elif isinstance(destinations, (resource_variable_ops.ResourceVariable,
+ value_lib.AggregatingVariable)):
return [destinations.device]
elif isinstance(destinations, six.string_types):
return [device_util.resolve(destinations)]
@@ -756,7 +757,7 @@ class CollectiveAllReduce(CrossTowerOps):
)
super(CollectiveAllReduce, self).__init__()
- # TODO(yuefengz, tucker): is index slices supported by collective ops?
+ # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, aggregation, per_device_value, destinations):
all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
if destinations is None or _devices_match(per_device_value, destinations):
@@ -768,8 +769,10 @@ class CollectiveAllReduce(CrossTowerOps):
if d in all_reduced._index:
index[d] = all_reduced._index[d]
else:
- with ops.device(d):
+ with ops.control_dependencies(list(
+ all_reduced._index.values())), ops.device(d):
index[d] = array_ops.identity(list(all_reduced._index.values())[0])
+
return value_lib.Mirrored(index)
def _batch_reduce(self, aggregation, value_destination_pairs):
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index aec53b01d7..3508c9d599 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -417,7 +417,7 @@ class MultiWorkerCollectiveAllReduceTest(
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
else:
devices = ["/device:CPU:0"]
- return collective_all_reduce_ops, devices, "local"
+ return collective_all_reduce_ops, devices, ""
else:
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
3, num_gpus, collective_keys=collective_keys)
@@ -476,7 +476,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_different, None, destination_mirrored, destination_str,
destination_list
]
@@ -540,6 +540,12 @@ class MultiWorkerCollectiveAllReduceTest(
self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
num_gpus)
+ # Collective ops doesn't support strategy with one device.
+ def testReductionLocal(self, num_gpus=2):
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_reduction(None, None, num_gpus, local_mode=True)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 4facd72d12..a262d7666e 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -367,8 +367,8 @@ class TestWithDistributionStrategy(test.TestCase):
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
- NotImplementedError, 'sample_weight is currently not supported when '
- 'using DistributionStrategy.'):
+ NotImplementedError, '`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.'):
model.fit(
dataset,
epochs=1,
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 2f3d6bdd3f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -68,6 +68,8 @@ def _regression_dataset_fn():
"predictions": [1., .75, .25, 0.]}).repeat()
+# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using
+# TowerLocalVariables on TPUs. Submit http://cl/208914352.
def all_combinations():
return combinations.combine(
distribution=[combinations.default_strategy,
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index aa7a61bb3b..516ede7ade 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -56,11 +56,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
return distribution.group(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=layer.built))
+ model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
@@ -153,11 +153,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
use_callable_loss=True,
create_optimizer_inside_model_fn=True)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
return distribution.group(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=layer.built))
+ model_fn, *inputs, run_concurrently=layer.built))
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
@@ -231,11 +231,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
if isinstance(distribution, mirrored_strategy.MirroredStrategy):
self.assertFalse(distribution._prefetch_on_device)
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
del ctx # Unused
fetches = distribution.unwrap(
distribution.call_for_each_tower(
- model_fn, inputs, run_concurrently=batchnorm.built))
+ model_fn, *inputs, run_concurrently=batchnorm.built))
if update_ops_in_cross_tower_mode:
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
@@ -328,9 +328,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
return dataset_ops.Dataset.zip((features, labels)).repeat()
- def step_fn(ctx, inputs):
+ def step_fn(ctx, x, y):
del ctx # Unused
- x, y = inputs
return distribution.group(
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
@@ -417,9 +416,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output_context.set_non_tensor_output(key1, value1)
return (train_op, loss)
- def step_fn(output_context, inputs):
+ def step_fn(output_context, *inputs):
(train_op, loss) = distribution.call_for_each_tower(
- model_fn, output_context, inputs, run_concurrently=False)
+ model_fn, output_context, *inputs, run_concurrently=False)
output_context.set_last_step_output(
name="cross_tower_loss_agg",
output=loss,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e3376a0636..6981449a4c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import contextlib
+from functools import partial
import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
@@ -37,6 +39,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -291,24 +294,112 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
class MirroredStrategy(distribute_lib.DistributionStrategy):
- """Mirrors vars to distribute across multiple devices on a single machine.
+ """Mirrors vars to distribute across multiple devices and machines.
+
+ This strategy uses one tower per device and sync replication for its multi-GPU
+ version.
+
+ When `cluster_spec` is given, it turns into the mulit-worker version that
+ works on multiple workers with in-graph replication.
+
+ There are several important concepts for distributed TensorFlow, e.g.
+ `client`, `job`, 'task', `cluster`, `in-graph replication` and
+ 'synchronous training' and they have already been defined in the
+ [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+ The distribution strategy inherits these concepts as well and in addition to
+ that we also clarify several more concepts:
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
+ specifies tasks for devices on all workers. The `client` then creates a
+ client session which will talk to the `master` service of a `worker`. Then
+ the `master` will partition the graph and distribute the work to all
+ participating workers.
+ * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+ physical machine. We will have multiple `worker`s with different `task`
+ index. They all do similar things except for one worker checkpointing model
+ variables, writing summaries, etc. in addition to its ordinary work.
+
+ The multi-worker version of this class maps one tower to one device on a
+ worker. It mirrors all model variables on all towers. For example, if you have
+ two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the
+ model variables on these 8 GPUs. Then like in MirroredStrategy, each tower
+ performs their computation with their own copy of variables unless in
+ cross-tower model where variable or tensor reduction happens.
- This strategy uses one tower per device and sync replication.
+ Args:
+ devices: a list of device strings.
+ num_gpus: number of GPUs. For local training, either specify `devices` or
+ `num_gpus`. In distributed training, this must be specified as number of
+ GPUs on each worker.
+ cluster_spec: if this is set, it turns into the multi-worker version and
+ `devices` must not be set but `num_gpus` must be set.
+ cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not
+ set, the `configure` method will try to find the best one.
+ prefetch_on_device: optional boolean to specify whether to prefetch input
+ data to devices.
"""
def __init__(self,
devices=None,
num_gpus=None,
+ cluster_spec=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
- # Convert `num_gpus` into `devices`, shouldn't specify both.
- if devices is None:
+
+ if cluster_spec:
+ if devices is not None:
+ raise ValueError("Specifying devices when `cluster_spec` is also given "
+ "is not supported in MirroredStrategy.")
+
+ # TODO(yuefengz): use the utility method to normalize cluster_spec.
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ cluster_spec = server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in sorted(cluster_spec.jobs):
+ for task in range(cluster_spec.num_tasks(job)):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
if num_gpus is None:
- num_gpus = context.num_gpus()
- devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
- elif num_gpus is not None:
- raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ self._num_gpus = num_gpus
+ if num_gpus > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, "/device:CPU:0")]
+ for worker in self._workers
+ }
+ devices = nest.flatten(self._worker_device_map)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the cpu device of its first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
+ else:
+ self._cluster_spec = None
+ # Convert `num_gpus` into `devices`, shouldn't specify both.
+ if devices is None:
+ if num_gpus is None:
+ num_gpus = context.num_gpus()
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ elif num_gpus is not None:
+ raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ # TODO(yuefengz): consider setting the default device.
assert devices, "Must specify at least one device."
assert len(set(devices)) == len(devices), (
@@ -320,7 +411,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
{d: i for i, d in enumerate(devices)})
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
- # TODO(yuefengz): consider setting the default device.
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
@@ -357,9 +447,14 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**kwargs)
def distribute_dataset(self, dataset_fn):
- return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+ self._prefetch_on_device)
+ else:
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
@@ -372,7 +467,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_result = fn(ctx, iterator.get_next())
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
for (name, output) in ctx.last_step_outputs.items():
# Convert all outputs to tensors, potentially from `DistributedValues`.
ctx.last_step_outputs[name] = self.unwrap(output)
@@ -380,12 +478,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
@@ -432,10 +539,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
- def configure(self, session_config=None):
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del cluster_spec, task_type, task_id
if self._cross_tower_ops is None:
- self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
- self._devices, session_config=session_config)
+ if self._cluster_spec:
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
+ else:
+ self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 5db2fff239..55d59adc07 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -19,12 +19,16 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import server_lib
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -86,5 +90,33 @@ class VariableCreatorStackTest(test.TestCase):
self.assertEquals(expected, result)
+class MultiWorkerMirroredStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ def _get_distribution_strategy(self):
+ return mirrored_strategy.MirroredStrategy(
+ cluster_spec=server_lib.ClusterSpec({
+ 'worker': ['/job:worker/task:0', '/job:worker/task:1']
+ }),
+ num_gpus=context.num_gpus())
+
+ def testMinimizeLossGraph(self):
+ self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+ def testDeviceScope(self):
+ """Test the device scope of multi-worker MirroredStrategy."""
+ with context.graph_mode():
+ strategy = mirrored_strategy.MirroredStrategy(
+ cluster_spec={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
+ num_gpus=context.num_gpus())
+ with strategy.scope():
+ a = constant_op.constant(1.)
+ with ops.device('/cpu:0'):
+ b = constant_op.constant(1.)
+ self.assertEqual(a.device, '/job:worker/task:0')
+ self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
deleted file mode 100644
index cbfe5df61d..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from functools import partial
-
-from tensorflow.contrib.distribute.python import values
-from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
-from tensorflow.core.protobuf import cluster_pb2
-from tensorflow.python.training import device_util
-from tensorflow.python.training import server_lib
-from tensorflow.python.util import nest
-
-
-# TODO(yuefengz): support between-graph replication.
-# TODO(yuefengz): merge this class into its base class.
-# TODO(yuefengz): in some cases, we probably want to use configure method to
-# configure this class.
-# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the
-# class is introduced.
-class MultiWorkerMirroredStrategy(MirroredStrategy):
- """Mirrored strategy that works on multiple workers with in-graph replication.
-
- There are several important concepts for distributed TensorFlow, e.g.
- `client`, `job`, 'task', `cluster`, `in-graph replication` and
- 'synchronous training' and they have already been defined in the
- [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
- The distribution strategy inherits these concepts as well and in addition to
- that we also clarify several more concepts:
- * **In-graph replication**: the `client` creates a single `tf.Graph` that
- specifies tasks for devices on all workers. The `client` then creates a
- client session which will talk to the `master` service of a `worker`. Then
- the `master` will partition the graph and distribute the work to all
- participating workers.
- * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
- physical machine. We will have multiple `worker`s with different `task`
- index. They all do similar things except for one worker checkpointing model
- variables, writing summaries, etc. in addition to its ordinary work.
-
- This class maps one tower to one device on a worker. It mirrors all model
- variables on all towers. For example, if you have two `worker`s and each
- `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8
- GPUs. Then like in MirroredStrategy, each tower performs their computation
- with their own copy of variables unless in cross-tower model where variable or
- tensor reduction happens.
- """
-
- def __init__(self,
- num_gpus_per_worker=1,
- worker_job_name=None,
- num_workers=None,
- cluster=None,
- cross_tower_ops=None,
- prefetch_on_device=None):
- """Initialize the strategy object.
-
- Args:
- num_gpus_per_worker: number of GPUs per work. If it is zero, the local
- CPU will be used.
- worker_job_name: the job name for `worker`, typically just 'worker'.
- num_workers: the number of workers. If it is 0, it regenerates to
- single-worker MirroredStrategy.
- cluster: a `tf.train.ClusterSpec` object or a dict that can be used to
- construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef`
- proto buffer. It is an alternative way to initialize this object.
- cross_tower_ops: the cross tower ops to use. If None, a default one will
- be used. If configure method is called, a best one for the configuration
- will be chosen.
- prefetch_on_device: a boolean to specify whether to prefetech input to
- each worker's devices.
-
- Raises:
- ValueError: if got an unexpected `cluster`.
- """
- if cluster is None:
- self._workers = [
- '/job:%s/task:%d' % (worker_job_name, task_index)
- for task_index in range(num_workers)
- ]
- else:
- if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster)
- elif isinstance(cluster, server_lib.ClusterSpec):
- cluster_spec = cluster
- else:
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- '`tf.train.ClusterDef` object')
-
- self._workers = []
- for job in sorted(cluster_spec.jobs):
- for task in range(cluster_spec.num_tasks(job)):
- self._workers.append('/job:%s/task:%d' % (job, task))
-
- self._num_gpus_per_worker = num_gpus_per_worker
- if num_gpus_per_worker > 0:
- self._worker_device_map = {
- worker: [
- device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
- for gpu in range(num_gpus_per_worker)
- ] for worker in self._workers
- }
- else:
- self._worker_device_map = {
- worker: [device_util.canonicalize(worker, '/device:CPU:0')]
- for worker in self._workers
- }
- self._devices = nest.flatten(self._worker_device_map)
-
- super(MultiWorkerMirroredStrategy, self).__init__(
- devices=self._devices, prefetch_on_device=prefetch_on_device)
-
- # Setting `_default_device` will add a device scope in the
- # distribution.scope. We set the default device to the first worker. When
- # users specify device under distribution.scope by
- # with tf.device("/cpu:0"):
- # ...
- # their ops will end up on the cpu device of its first worker, e.g.
- # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
- self._default_device = self._workers[0]
-
- def distribute_dataset(self, dataset_fn):
- return values.MultiWorkerDataset(
- partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
deleted file mode 100644
index 09c859b32a..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MultiWorkerMirroredStrategy."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import multi_worker_strategy
-from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
-from tensorflow.python.eager import context
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.training import server_lib
-
-
-class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- strategy_test_lib.DistributionTestBase):
-
- def _get_distribution_strategy(self):
- return multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster=server_lib.ClusterSpec({
- 'worker': ['/job:worker/task:0', '/job:worker/task:1']
- }),
- num_gpus_per_worker=context.num_gpus())
-
- def testMinimizeLossGraph(self):
- self._test_minimize_loss_graph(self._get_distribution_strategy())
-
-
-class DeviceScopeTest(test.TestCase):
- """Test the device scope of MultiWorkerMirroredStrategy."""
-
- def testDeviceScope(self):
- with context.graph_mode():
- strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
- num_gpus_per_worker=context.num_gpus())
- with strategy.scope():
- a = constant_op.constant(1.)
- with ops.device('/cpu:0'):
- b = constant_op.constant(1.)
- self.assertEqual(a.device, '/job:worker/task:0')
- self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 016978cdb3..68561b5bbf 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -80,18 +80,30 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
def body(i, *args):
"""A wrapper around `fn` to create the while loop body."""
del args
- fn_result = fn(ctx, iterator.get_next())
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
+ # TODO(priyag): Use max_iterations instead of an explicit counter.
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
- # TODO(priyag): Use max_iterations instead of an explicit counter.
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 407c78df95..96b6519bc4 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -18,38 +18,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import os
-
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
_LOCAL_CPU = "/device:CPU:0"
_LOCAL_GPU_0 = "/device:GPU:0"
-def _normalize_cluster_spec(cluster_spec):
- """Makes `cluster_spec` into a `ClusterSpec` object."""
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- return server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- return cluster_spec
-
-
# TODO(yuefengz): maybe cache variables on local CPU.
# TODO(yuefengz): we may want to set session options to disallow communication
# between workers.
@@ -70,7 +56,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
assigned to.
This class assumes between-graph replication will be used and works on a graph
- for a particular worker.
+ for a particular worker. Note that each graph and worker is independent.
+ This means that while each worker will synchronously compute a single gradient
+ update across all GPUs, updates between workers proceed asynchronously.
+ Operations that occur only on the first tower (such as incrementing the global
+ step), will occur on the first tower *of every worker*.
It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any
operations which potentially can be replicated across towers (i.e. multiple
@@ -88,7 +78,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
3) It is also not recommended to open a colocation scope (i.e. calling
`tf.colocate_with`) under the strategy's scope. For colocating variables,
use `distribution.colocate_vars_with` instead. Colocation of ops will possibly
- create conflicts of device assignement.
+ create conflicts of device assignment.
"""
def __init__(self,
@@ -96,7 +86,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster_spec=None,
task_type=None,
task_id=None):
- """Initiailizes this strategy.
+ """Initializes this strategy.
Args:
num_gpus_per_worker: number of local GPUs or GPUs per worker.
@@ -108,7 +98,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
if cluster_spec:
- cluster_spec = _normalize_cluster_spec(cluster_spec)
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
self._cluster_spec = cluster_spec
# We typically don't need to do all-reduce in this strategy.
@@ -216,6 +206,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
else:
self._default_device = self._worker_device
+ self._is_chief = cluster_spec is None or multi_worker_util.is_chief(
+ cluster_spec, task_type, task_id)
+
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
return values.PerDeviceDataset(
@@ -229,14 +222,30 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through
# this creator, such as "MutableHashTable".
def _create_variable(self, next_creator, *args, **kwargs):
+ if self.num_towers > 1:
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in (
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN
+ ):
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ def var_creator(*args, **kwargs):
+ v = next_creator(*args, **kwargs)
+ return values.AggregatingVariable(v, aggregation)
+ else:
+ var_creator = next_creator
+
if "colocate_with" in kwargs:
with ops.device(None):
with ops.colocate_with(kwargs["colocate_with"]):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._variable_device):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
def _call_for_each_tower(self, fn, *args, **kwargs):
# pylint: disable=protected-access
@@ -258,7 +267,6 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
-
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
@@ -291,6 +299,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
def _update(self, var, fn, *args, **kwargs):
+ if isinstance(var, values.AggregatingVariable):
+ var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
@@ -319,26 +329,31 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# No need to distinguish between normal variables and tower-local variables.
return array_ops.identity(var)
- def configure(self, session_config=None):
- del session_config
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the strategy class.
- # Use TF_CONFIG to get the cluster spec and the current job.
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+ The strategy object will be re-initialized if `cluster_spec` is given but
+ was not passed in the constructor.
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", "worker")
- task_id = int(task_env.get("index", "0"))
- else:
- task_type = "worker"
- task_id = None
+ Args:
+ session_config: not used currently.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type.
+ task_id: the current task id.
+ """
+ del session_config
# Set the devices if cluster_spec is defined in TF_CONFIG but not passed in
# the constructor.
if not self._cluster_spec and cluster_spec:
- self._cluster_spec = cluster_spec
- self._initialize_devices(self._num_gpus_per_worker, cluster_spec,
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec)
+ self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec,
task_type, task_id)
@property
@@ -356,3 +371,19 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name)
+
+ @property
+ def between_graph(self):
+ return True
+
+ @property
+ def should_init(self):
+ return self._is_chief
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
+
+ @property
+ def should_save_summary(self):
+ return self._is_chief
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 02eb68227d..adfe3e8b02 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
import threading
from absl.testing import parameterized
@@ -69,19 +68,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
if not task_type:
return distribution, ''
- tf_config = {
- 'cluster': self._cluster_spec,
- 'task': {
- 'type': task_type,
- 'index': task_id
- }
- }
- with self._lock:
- # Accessing environment variables should be protected by locks because
- # environment variables are shared by all threads.
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
- distribution.configure()
+ distribution.configure(
+ cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
return distribution, self._workers[task_id].target
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
@@ -113,7 +101,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/job:worker/task:0'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
# The variable x is on the task 1 since the device_function has been
@@ -125,18 +115,26 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(y.device, '/job:ps/task:1')
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(z.device, '/job:ps/task:0')
self.assertNotEqual(z.device, x.device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, worker_device + '/' + last_part_device)
@@ -214,7 +212,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/device:GPU:2'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
self.assertEqual(
@@ -224,19 +224,27 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(
device_util.canonicalize(y.device), tower_variable_device)
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(
device_util.canonicalize(z.device), tower_variable_device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, tower_compute_device)
@@ -298,11 +306,18 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
d.scope():
def model_fn():
- x = variable_scope.get_variable('x', initializer=10.0)
- y = variable_scope.get_variable('y', initializer=20.0)
-
- x_add = x.assign_add(1.0, use_locking=True)
- y_add = y.assign_add(1.0, use_locking=True)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+
+ # We explicitly make a constant tensor here to avoid complaints about
+ # summing non-distributed values.
+ one = constant_op.constant(1.0)
+ x_add = x.assign_add(one, use_locking=True)
+ y_add = y.assign_add(one, use_locking=True)
train_op = control_flow_ops.group([x_add, y_add])
return x, y, train_op
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index d3611570b4..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -90,14 +90,14 @@ class StandardSingleLossStep(StandardInputStep):
def __call__(self):
with self._distribution.scope():
- def step_fn(ctx, inputs):
+ def step_fn(ctx, *inputs):
"""Function to run one iteration with one input."""
gradients_fn = backprop.implicit_grad(self._loss_fn)
gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
grads_and_vars = self.distribution.call_for_each_tower(
gradients_fn,
- ctx, inputs,
+ ctx, *inputs,
run_concurrently=self._is_run_concurrently)
# If threads use layers, then we need to run the first step
# sequentially, so that layers.build() is not executed in parallel.
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index b510fdb888..a486003076 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -37,7 +37,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
-from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -46,13 +45,13 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
master = tpu_cluster_resolver.master()
# pylint: disable=protected-access
- cluster_def = (tpu_cluster_resolver.cluster_spec()
- or server_lib.ClusterSpec({})).as_cluster_def()
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
- query_topology=True))
+ query_topology=False))
return tpu_system_metadata
@@ -144,7 +143,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
ctx = values.MultiStepContext()
def run_fn(*args, **kwargs):
del args, kwargs
- fn_result = fn(ctx, dequeue_fn())
+ fn_inputs = dequeue_fn()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, *fn_inputs)
flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
if flat_last_step_outputs:
with ops.control_dependencies([fn_result]):
@@ -157,8 +159,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def iterate_on_tpu():
return training_loop.repeat(iterations, run_fn, initial_loop_values)
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop and TPU replicate context. This is useful in cases
+ # where we might need to exit these contexts and get back to the outer
+ # context to do some things, for e.g. create an op which should be
+ # evaluated only once at the end of the loop on the host. One such usage
+ # is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
replicate_inputs = [[]] * self.num_towers
replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
# Filter out any ops from the outputs, typically this would be the case
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8548a86421..a58bb3a849 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -308,26 +308,6 @@ class DistributedVariable(DistributedDelegate):
ops.register_dense_tensor_like_type(DistributedVariable)
-def _get_update_device():
- """Validate we are in update/update_non_slot() and return current device.
-
- This is used in MirroredVariable.assign* members, to make sure they
- are only called via an update method, to make sure all components of the
- variable are being updated in a consistent way.
-
- Returns:
- A string device.
-
- Raises:
- RuntimeError: If not in distribution.update()/.update_non_slot().
- """
- device = distribute_lib.get_update_device()
- if device is None:
- raise RuntimeError(
- "Use DistributionStrategy.update() to modify a MirroredVariable.")
- return device
-
-
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
"""Class for defining how to restore a MirroredVariable."""
@@ -366,13 +346,14 @@ class MirroredVariable(DistributedVariable, Mirrored,
f = kwargs.pop("f")
if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
- # We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
- # We are calling an assign function on the mirrored variable in cross
- # tower context.
+ # We are calling an assign function on the mirrored variable in an
+ # update context.
v = self.get(device=update_device)
return f(v, *args, **kwargs)
+ # We are calling assign on the mirrored variable in cross tower context,
+ # use update to update the variable.
return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
@@ -1057,3 +1038,156 @@ def value_container(val):
if container is not None:
return container
return val
+
+
+# TODO(josh11b): Descend from Variable.
+class AggregatingVariable(checkpointable.CheckpointableBase):
+ """A wrapper around a variable that aggregates updates across towers."""
+
+ def __init__(self, v, aggregation):
+ self._v = v
+ # TODO(josh11b): Set v._distributed_container?
+ # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+ self._aggregation = aggregation
+
+ def get(self):
+ return self._v
+
+ def __getattr__(self, name):
+ return getattr(self._v, name)
+
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ if update_device is not None:
+ # We are calling an assign function in an update context.
+ return f(self._v, *args, **kwargs)
+
+ # We are calling an assign function in cross tower context, wrap it in an
+ # update call.
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ assert distribution_strategy_context.get_tower_context()
+ # We are calling an assign function in tower context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function with the reduced value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "a variable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ def assign_sub(self, *args, **kwargs):
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def name(self):
+ return self._v.name
+
+ @property
+ def dtype(self):
+ return self._v.dtype
+
+ # TODO(josh11b): Test saving & restoring.
+ def _gather_saveables_for_checkpoint(self):
+ return {checkpointable.VARIABLE_VALUE_KEY: self._v}
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self._v + o
+ def __radd__(self, o): return o + self._v
+ def __sub__(self, o): return self._v - o
+ def __rsub__(self, o): return o - self._v
+ def __mul__(self, o): return self._v * o
+ def __rmul__(self, o): return o * self._v
+ def __truediv__(self, o): return self._v / o
+ def __rtruediv__(self, o): return o / self._v
+ def __floordiv__(self, o): return self._v // o
+ def __rfloordiv__(self, o): return o // self._v
+ def __mod__(self, o): return self._v % o
+ def __rmod__(self, o): return o % self._v
+ def __lt__(self, o): return self._v < o
+ def __le__(self, o): return self._v <= o
+ def __gt__(self, o): return self._v > o
+ def __ge__(self, o): return self._v >= o
+ def __and__(self, o): return self._v & o
+ def __rand__(self, o): return o & self._v
+ def __or__(self, o): return self._v | o
+ def __ror__(self, o): return o | self._v
+ def __xor__(self, o): return self._v ^ o
+ def __rxor__(self, o): return o ^ self._v
+ def __getitem__(self, o): return self._v[o]
+ def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
+ def __rpow__(self, o): return pow(o, self._v)
+ def __invert__(self): return ~self._v
+ def __neg__(self): return -self._v
+ def __abs__(self): return abs(self._v)
+
+ def __div__(self, o):
+ try:
+ return self._v.__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self._v.__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self._v.__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self._v.__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __str__(self):
+ return str(self._v)
+
+ def __repr__(self):
+ return repr(self._v)
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(
+ AggregatingVariable, _tensor_conversion_aggregate)
+ops.register_dense_tensor_like_type(AggregatingVariable)
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index f7933639a0..fa3f1bb7ad 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -104,7 +104,6 @@ cuda_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python/eager:graph_callable",
"//tensorflow/python/eager:test",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
index 0736ed02b7..e5058bfd94 100644
--- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -218,7 +218,7 @@ class DensenetBenchmark(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
@@ -228,7 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark):
weight_decay=1e-4, dropout_rate=0,
pool_initial=True, include_top=True)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -264,8 +264,7 @@ class DensenetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -279,8 +278,8 @@ class DensenetBenchmark(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 07d8788882..d265169b5e 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -216,12 +216,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -257,8 +257,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -267,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 84b2ddf0de..6a921e1997 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -226,14 +226,13 @@ class RevNetBenchmark(tf.test.Benchmark):
label,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = revnet.RevNet(config=config)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 10
@@ -271,8 +270,7 @@ class RevNetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
index 90a3711475..91bc75213c 100644
--- a/tensorflow/contrib/eager/python/saver_test.py
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -21,15 +21,11 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
-from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
@@ -142,53 +138,6 @@ class SaverTest(test.TestCase):
with _saver.restore_variables_on_create(ckpt_prefix):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
- def testSaveRestoreGraphCallable(self):
- with ops.device(self._dev()):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- # Default 2 + 0 = 2
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # Save the variable value 0.
- ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
- _saver.Saver(model.variables).save(ckpt_prefix)
-
- # update variable to 1, so that 2 + 1 = 3
- model.variables[0].assign(1.)
- self.assertEqual(
- 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # load the variable value 0, so that 2 + 0 = 2
- _saver.Saver(model.variables).restore(ckpt_prefix)
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # update checkpoint variable to 1 and memory value to 2.
- model.variables[0].assign(1.)
- _saver.Saver(model.variables).save(ckpt_prefix)
- model.variables[0].assign(2.)
- self.assertEqual(
- 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # reset the graph and reload on create, so that 1 + 2 = 3
- ops.reset_default_graph()
- with _saver.restore_variables_on_create(ckpt_prefix):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model2(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
class GetOptimizerTests(test.TestCase):
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index de11d00a1a..4dfd083443 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -16,7 +16,7 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
-To use, at program startup, call `tfe.enable_eager_execution()`.
+To use, at program startup, call `tf.enable_eager_execution()`.
@@metrics
@@ -67,6 +67,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@execution_mode
@@async_wait
@@async_clear_error
+@@set_server_def
@@run_test_in_graph_and_eager_modes
@@run_all_tests_in_graph_and_eager_modes
@@ -110,6 +111,7 @@ from tensorflow.python.eager.context import async_clear_error
from tensorflow.python.eager.context import SYNC
from tensorflow.python.eager.context import ASYNC
from tensorflow.python.eager.context import num_gpus
+from tensorflow.python.eager.context import set_server_def
from tensorflow.python.eager.execution_callbacks import add_execution_callback
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
from tensorflow.python.eager.execution_callbacks import inf_callback
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 6ad3a4a604..258860f263 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,7 +45,7 @@ _allowed_symbols = [
'clip_gradients_by_norm',
'forward_features',
'InMemoryEvaluatorHook',
- 'StopAtCheckpointStepHook',
+ 'make_stop_at_checkpoint_step_hook',
'logistic_regression_head',
'multi_class_head',
'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py
index 03cf6f107c..b0deb9b494 100644
--- a/tensorflow/contrib/estimator/python/estimator/export.py
+++ b/tensorflow/contrib/estimator/python/estimator/export.py
@@ -31,8 +31,8 @@ def export_saved_model_for_mode(
# pylint: disable=line-too-long
"""Exports a single train/eval/predict graph as a SavedModel.
- For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ For a detailed guide, see [Using SavedModel with Estimators](
+ https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
Sample usage:
```python
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index faefda7c48..66c46e66b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -74,8 +74,9 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
estimator: A `tf.estimator.Estimator` instance to call evaluate.
input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A
function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Createing input functions](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
@@ -212,8 +213,12 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
-class StopAtCheckpointStepHook(training.SessionRunHook):
- """Hook that requests stop at a specified step based on checkpoint."""
+class _StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint.
+
+ Note: We recommend using 'make_stop_at_checkpoint_step_hook` to get the proper
+ hook.
+ """
def __init__(self, model_dir, last_step,
wait_after_file_check_secs=30):
@@ -263,4 +268,17 @@ class StopAtCheckpointStepHook(training.SessionRunHook):
else:
time.sleep(self._wait_after_file_check_secs)
+
+def make_stop_at_checkpoint_step_hook(estimator,
+ last_step,
+ wait_after_file_check_secs=30):
+ """Creates a proper StopAtCheckpointStepHook based on chief status."""
+
+ if estimator.config.is_chief:
+ return training.StopAtStepHook(last_step=last_step)
+ return _StopAtCheckpointStepHook(
+ model_dir=estimator.model_dir,
+ last_step=last_step,
+ wait_after_file_check_secs=wait_after_file_check_secs)
+
# pylint: enable=protected-access
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index 42352aa3ff..c6c6cad95a 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -326,7 +326,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=tempfile.mkdtemp(), last_step=10)
with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
mon_sess.raw_session().run(assign_ten)
@@ -342,7 +342,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
assign_nine = step.assign(9)
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_nine)
@@ -360,7 +360,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_ten)
@@ -372,6 +372,32 @@ class StopAtCheckpointStepHookTest(test.TestCase):
self.assertFalse(mock_sleep.called)
self.assertTrue(mon_sess.should_stop())
+ def test_creates_regular_stop_at_step_hook_for_chief(self):
+ # by default an estimator is in chief mode
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1])
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, training.StopAtStepHook)
+ self.assertEqual(300, hook._last_step)
+
+ def test_creates_checkpoint_hook_for_workers(self):
+
+ class FakeWorkerConfig(estimator_lib.RunConfig):
+
+ @property
+ def is_chief(self):
+ return False
+
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1],
+ config=FakeWorkerConfig())
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook)
+ self.assertEqual(300, hook._last_step)
+ self.assertEqual(dnn.model_dir, hook._model_dir)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py
index 484ffee3e7..3a756da932 100644
--- a/tensorflow/contrib/ffmpeg/__init__.py
+++ b/tensorflow/contrib/ffmpeg/__init__.py
@@ -15,7 +15,7 @@
# pylint: disable=g-short-docstring-punctuation
"""Working with audio using FFmpeg.
-See the @{$python/contrib.ffmpeg} guide.
+See the [FFMPEG](https://tensorflow.org/api_guides/python/contrib.ffmpeg) guide.
@@decode_audio
@@encode_audio
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 20d099fe5d..95f5ba90ab 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -15,7 +15,9 @@
"""Framework utilities.
-See the @{$python/contrib.framework} guide.
+See the
+[Contrib Framework](https://tensorflow.org/api_guides/python/contrib.framework)
+guide.
@@assert_same_float_dtype
@@assert_scalar
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
index 72835c3ad8..71ab755aa2 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
@@ -325,6 +325,8 @@ class CriticalSection(object):
def _is_self_handle(self, x):
"""Check if the tensor `x` is the same Mutex as `self._handle`."""
+ if isinstance(x, ops.EagerTensor):
+ return x is self._handle
return (x.op.type == "MutexV2"
# blank shared_name means the op will create a unique one.
and x.op.get_attr("shared_name")
@@ -365,8 +367,7 @@ class CriticalSection(object):
"(CriticalSection: %s) requested exclusive resource access "
"of this resource. Did you mean to call execute with keyword "
"argument exclusive_resource_access=False?" %
- (list(resource_intersection), self._handle.name,
- sg.op.name, sg.handle.name))
+ (list(resource_intersection), self._handle, sg, sg.handle))
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py
index 5d269fefdc..d5cb679e2c 100644
--- a/tensorflow/contrib/framework/python/ops/script_ops.py
+++ b/tensorflow/contrib/framework/python/ops/script_ops.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-"""Script Language Operators. See the @{$python/script_ops} guide.
+"""Script Language Operators.
@@py_func
"""
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
index 7534f5797c..869e899ac8 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
-#define THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -62,4 +62,4 @@ class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T, BiasType,
} // namespace tensorflow
-#endif
+#endif // TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py
index 51b7f45274..b2de2b9a69 100644
--- a/tensorflow/contrib/graph_editor/__init__.py
+++ b/tensorflow/contrib/graph_editor/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""TensorFlow Graph Editor.
-See the @{$python/contrib.graph_editor} guide.
+See the
+[Graph Editor](https://tensorflow.org/api_guides/python/contrib.graph_editor)
+guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
index 1939caaa2d..3054128979 100644
--- a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
@@ -226,6 +227,81 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
interp_val = sess.run(interpolator)
self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+ def test_nd_linear_interpolation_unspecified_shape(self):
+ """Ensure that interpolation supports dynamic batch_size and num_points."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph, query_points_ph, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+
+ (train_points_value, train_values_value, query_points_value) = sess.run(
+ [train_points, train_values, query_points])
+
+ interp_val = sess.run(
+ interpolator,
+ feed_dict={
+ train_points_ph: train_points_value,
+ train_values_ph: train_values_value,
+ query_points_ph: query_points_value
+ })
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_fully_unspecified_shape(self):
+ """Ensure that erreor is thrown when input/output dim unspecified."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_points_ph_invalid = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, None])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ train_values_ph_invalid = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, None])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph_invalid, train_values_ph, query_points_ph, order,
+ reg_weight)
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph_invalid, query_points_ph, order,
+ reg_weight)
+
def test_interpolation_gradient(self):
"""Make sure that backprop can run. Correctness of gradients is assumed.
diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py
index daf8c56456..f0b408faa3 100644
--- a/tensorflow/contrib/image/python/ops/interpolate_spline.py
+++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py
@@ -17,9 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@@ -95,10 +92,22 @@ def _solve_interpolation(train_points, train_values, order,
Returns:
w: `[b, n, k]` weights on each interpolation center
v: `[b, d, k]` weights on each input dimension
+ Raises:
+ ValueError: if d or k is not fully specified.
"""
- b, n, d = train_points.get_shape().as_list()
- _, _, k = train_values.get_shape().as_list()
+ # These dimensions are set dynamically at runtime.
+ b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3)
+
+ d = train_points.shape[-1]
+ if d.value is None:
+ raise ValueError('The dimensionality of the input points (d) must be '
+ 'statically-inferrable.')
+
+ k = train_values.shape[-1]
+ if k.value is None:
+ raise ValueError('The dimensionality of the output values (k) must be '
+ 'statically-inferrable.')
# First, rename variables so that the notation (c, f, w, v, A, B, etc.)
# follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
@@ -113,14 +122,12 @@ def _solve_interpolation(train_points, train_values, order,
matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n]
if regularization_weight > 0:
- batch_identity_matrix = np.expand_dims(np.eye(n), 0)
- batch_identity_matrix = constant_op.constant(
- batch_identity_matrix, dtype=train_points.dtype)
-
+ batch_identity_matrix = array_ops.expand_dims(
+ linalg_ops.eye(n, dtype=c.dtype), 0)
matrix_a += regularization_weight * batch_identity_matrix
# Append ones to the feature values for the bias term in the linear model.
- ones = array_ops.ones([b, n, 1], train_points.dtype)
+ ones = array_ops.ones_like(c[..., :1], dtype=c.dtype)
matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1]
# [b, n + d + 1, n]
@@ -164,9 +171,6 @@ def _apply_interpolation(query_points, train_points, w, v, order):
Polyharmonic interpolation evaluated at points defined in query_points.
"""
- batch_size = train_points.get_shape()[0].value
- num_query_points = query_points.get_shape()[1].value
-
# First, compute the contribution from the rbf term.
pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
phi_pairwise_dists = _phi(pairwise_dists, order)
@@ -177,7 +181,7 @@ def _apply_interpolation(query_points, train_points, w, v, order):
# Pad query_points with ones, for the bias term in the linear model.
query_points_pad = array_ops.concat([
query_points,
- array_ops.ones([batch_size, num_query_points, 1], train_points.dtype)
+ array_ops.ones_like(query_points[..., :1], train_points.dtype)
], 2)
linear_term = math_ops.matmul(query_points_pad, v)
@@ -251,6 +255,9 @@ def interpolate_spline(train_points,
Note the interpolation procedure is differentiable with respect to all inputs
besides the order parameter.
+ We support dynamically-shaped inputs, where batch_size, n, and m are None
+ at graph construction time. However, d and k must be known.
+
Args:
train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
locations. These do not need to be regularly-spaced.
diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py
index 694f0c14bd..3c37f152e5 100644
--- a/tensorflow/contrib/integrate/__init__.py
+++ b/tensorflow/contrib/integrate/__init__.py
@@ -15,7 +15,9 @@
"""Integration and ODE solvers.
-See the @{$python/contrib.integrate} guide.
+See the
+[Contrib Integrate](https://tensorflow.org/api_guides/python/contrib.integrate)
+guide.
@@odeint
@@odeint_fixed
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 7355a403ae..b4fe8cac74 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -185,7 +185,7 @@ py_test(
py_test(
name = "normalization_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/normalization_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"], # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index a7b41b714f..af8e673f59 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Ops for building neural network layers, regularizers, summaries, etc.
-See the @{$python/contrib.layers} guide.
+See the
+[Contrib Layers](https://tensorflow.org/api_guides/python/contrib.layers)
+guide.
@@avg_pool2d
@@avg_pool3d
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 3ae07cedab..28d19a0445 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -997,9 +997,14 @@ class _OneHotColumn(
# Remove (?, -1) index
weighted_column = sparse_ops.sparse_slice(
weighted_column,
- [0, 0],
+ array_ops.zeros_like(weighted_column.dense_shape),
weighted_column.dense_shape)
- return sparse_ops.sparse_tensor_to_dense(weighted_column)
+ dense_tensor = sparse_ops.sparse_tensor_to_dense(weighted_column)
+ batch_shape = array_ops.shape(dense_tensor)[:-1]
+ dense_tensor_shape = array_ops.concat(
+ [batch_shape, [self.length]], axis=0)
+ dense_tensor = array_ops.reshape(dense_tensor, dense_tensor_shape)
+ return dense_tensor
dense_id_tensor = sparse_ops.sparse_tensor_to_dense(sparse_id_column,
default_value=-1)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index 1de9ab7056..eaaf9f8d5f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -57,6 +57,29 @@ def _sparse_id_tensor(shape, vocab_size, seed=112123):
indices=indices, values=values, dense_shape=shape)
+def _sparse_id_tensor_with_weights(shape, vocab_size, seed=112123):
+ # Returns a arbitrary `SparseTensor` with given shape and vocab size.
+ assert vocab_size >= shape[-1]
+ np.random.seed(seed)
+ indices = np.array(list(itertools.product(*[range(s) for s in shape])))
+
+ # Values must be distinct from the vocab
+ values = np.ndarray.flatten(np.array([
+ np.random.choice(vocab_size, size=shape[-1], replace=False)
+ for _ in range(np.prod(shape[:-1]))]))
+ weights = np.sort(np.random.rand(*shape), axis=len(shape)-1)
+
+ # Remove entries if weight < 0.5 for sparsity.
+ keep = np.ndarray.flatten(weights < 0.5) # Remove half of them
+ indices = indices[keep]
+ values = values[keep]
+ weights = np.ndarray.flatten(weights)[keep]
+ return (sparse_tensor_lib.SparseTensor(
+ indices=indices, values=values, dense_shape=shape),
+ sparse_tensor_lib.SparseTensor(
+ indices=indices, values=weights, dense_shape=shape))
+
+
class FeatureColumnTest(test.TestCase):
def testImmutability(self):
@@ -329,6 +352,34 @@ class FeatureColumnTest(test.TestCase):
self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights")
self.assertEqual(one_hot.length, 3)
+ def testIntegerizedOneHotColumnForWeightedSparseColumn(self):
+ vocab_size = 5
+ ids = fc.sparse_column_with_integerized_feature("ids", vocab_size)
+ weighted_ids = fc.weighted_sparse_column(ids, "weights")
+ one_hot = fc.one_hot_column(weighted_ids)
+ self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights")
+ self.assertEqual(one_hot.length, vocab_size)
+
+ def testIntegerizedOneHotWeightedSparseColumnShape(self):
+ vocab_size = 5
+ for id_tensor_shape in [[4, 3], [2, 4], [3, 3, 3]]:
+ output_rank = len(id_tensor_shape)
+ a = fc.sparse_column_with_integerized_feature("a", vocab_size)
+ weighted = fc.weighted_sparse_column(a, "weights")
+ one_hot = fc.one_hot_column(weighted)
+ id_tensor, weight_tensor = _sparse_id_tensor_with_weights(
+ id_tensor_shape, vocab_size)
+
+ one_hot_output = one_hot._to_dnn_input_layer(
+ (id_tensor, weight_tensor),
+ output_rank=output_rank)
+ one_hot_output_shape = one_hot_output.get_shape().as_list()
+ expected_shape = id_tensor_shape[:-1] + [vocab_size]
+ self.assertEquals(expected_shape, one_hot_output_shape)
+ with self.test_session() as sess:
+ one_hot_value = sess.run(one_hot_output)
+ self.assertEquals(expected_shape, list(one_hot_value.shape))
+
def testOneHotColumnWithSparseColumnWithHashKeys(self):
input_values = ["marlo", "unknown", "omar"]
inputs = constant_op.constant(input_values)
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index c807ab0f2e..11033a2e9c 100644
--- a/tensorflow/contrib/layers/python/layers/normalization.py
+++ b/tensorflow/contrib/layers/python/layers/normalization.py
@@ -176,7 +176,8 @@ def group_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
- scope=None):
+ scope=None,
+ mean_close_to_zero=False):
"""Functional interface for the group normalization layer.
Reference: https://arxiv.org/abs/1803.08494.
@@ -222,6 +223,19 @@ def group_norm(inputs,
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
scope: Optional scope for `variable_scope`.
+ mean_close_to_zero: The mean of `input` before ReLU will be close to zero
+ when batch size >= 4k for Resnet-50 on TPU. If `True`, use
+ `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the
+ variance. This is the same behavior as `fused` equals `True` in batch
+ normalization. If `False`, use `nn.moments` to calculate the variance.
+ When `mean` is close to zero, like 1e-4, use `mean` to calculate the
+ variance may have poor result due to repeated roundoff error and
+ denormalization in `mean`. When `mean` is large, like 1e2,
+ sum(`input`^2) is so large that only the high-order digits of the elements
+ are being accumulated. Thus, use sum(`input` - `mean`)^2/n to calculate
+ the variance has better accuracy compared to (sum(`input`^2)/n - `mean`^2)
+ when `mean` is large.
+
Returns:
A `Tensor` representing the output of the operation.
@@ -333,7 +347,14 @@ def group_norm(inputs,
gamma = array_ops.reshape(gamma, params_shape_broadcast)
# Calculate the moments.
- mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+ if mean_close_to_zero:
+ # One pass algorithm returns better result when mean is close to zero.
+ counts, means_ss, variance_ss, _ = nn.sufficient_statistics(
+ inputs, moments_axes, keep_dims=True)
+ mean, variance = nn.normalize_moments(
+ counts, means_ss, variance_ss, shift=None)
+ else:
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
# Compute normalization.
# TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index b6e96350db..55272e5fd1 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -293,8 +293,13 @@ class GroupNormTest(test.TestCase):
train_np, eval_np = sess.run([output_train, output_eval])
self.assertAllClose(train_np, eval_np)
- def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
- groups=2, tol=1e-2):
+ def doOutputTest(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ mean_close_to_zero=False,
+ groups=2,
+ tol=1e-2):
# Select the axis for the channel and the dimensions along which statistics
# are accumulated.
if channels_axis < 0:
@@ -322,17 +327,28 @@ class GroupNormTest(test.TestCase):
if i not in reduced_axes:
reduced_shape.append(a)
- for mu in (0.0, 1e2):
- for sigma in (1.0, 0.1):
+ if mean_close_to_zero:
+ mu_tuple = (1e-4, 1e-2, 1.0)
+ sigma_tuple = (1e-2, 0.1, 1.0)
+ else:
+ mu_tuple = (1.0, 1e2)
+ sigma_tuple = (1.0, 0.1)
+
+ for mu in mu_tuple:
+ for sigma in sigma_tuple:
# Determine shape of Tensor after normalization.
expected_mean = np.zeros(reduced_shape)
expected_var = np.ones(reduced_shape)
- inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu
output_op = normalization.group_norm(
- inputs, groups=groups, center=False, scale=False,
+ inputs,
+ groups=groups,
+ center=False,
+ scale=False,
channels_axis=channels_axis,
- reduction_axes=reduction_axes)
+ reduction_axes=reduction_axes,
+ mean_close_to_zero=mean_close_to_zero)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
@@ -347,12 +363,32 @@ class GroupNormTest(test.TestCase):
self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+ def doOutputTestForMeanCloseToZero(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ groups=2,
+ tol=5e-2):
+ self.doOutputTest(
+ input_shape,
+ channels_axis=channels_axis,
+ reduction_axes=reduction_axes,
+ groups=groups,
+ tol=tol,
+ mean_close_to_zero=True)
+
def testOutputSmallInput4D_NHWC(self):
input_shape = [10, 10, 10, 30]
# Specify axes with positive values.
self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=3, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput3D_NHWC(self):
input_shape = [10, 10, 30]
@@ -360,6 +396,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=2, reduction_axes=[0, 1])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput4D_NCHW(self):
input_shape = [10, 10, 10, 30]
@@ -367,6 +409,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=1, reduction_axes=[2, 3])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputSmallInput3D_NCHW(self):
input_shape = [10, 10, 30]
@@ -374,23 +422,43 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=0, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputBigInput4D_NHWC(self):
- self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
- groups=1)
+ self.doOutputTest(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
+ self.doOutputTestForMeanCloseToZero(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
def testOutputBigInput4D_NCHW(self):
- self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
- groups=4)
+ self.doOutputTest(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
+ self.doOutputTestForMeanCloseToZero(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
def testOutputSmallInput2D_NC(self):
- self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTest(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
def testOutputSmallInput5D_NCXXX(self):
- self.doOutputTest([10, 10, 20, 40, 5],
- channels_axis=1,
- reduction_axes=[2, 3, 4],
- groups=5)
+ self.doOutputTest(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index dad3da3748..b25f11b5a6 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -151,9 +151,19 @@ def _rev_block_forward(x1,
return y1, y2
+def _safe_wraps(fn):
+ if isinstance(fn, functools.partial):
+ # functools.partial objects cannot be wrapped as they are missing the
+ # necessary properties (__name__, __module__, __doc__).
+ def passthrough(f):
+ return f
+ return passthrough
+ return functools.wraps(fn)
+
+
def _scope_wrap(fn, scope):
- @functools.wraps(fn)
+ @_safe_wraps(fn)
def wrap(*args, **kwargs):
with variable_scope.variable_scope(scope, use_resource=True):
return fn(*args, **kwargs)
@@ -430,7 +440,7 @@ def rev_block(x1,
def enable_with_args(dec):
"""A decorator for decorators to enable their usage with or without args."""
- @functools.wraps(dec)
+ @_safe_wraps(dec)
def new_dec(*args, **kwargs):
if len(args) == 1 and not kwargs and callable(args[0]):
# Used as decorator without args
@@ -477,7 +487,7 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
tf.gradients).
"""
- @functools.wraps(fn)
+ @_safe_wraps(fn)
def wrapped(*args):
return _recompute_grad(
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py
index 79bd73faaf..28a6f5aed9 100644
--- a/tensorflow/contrib/learn/__init__.py
+++ b/tensorflow/contrib/learn/__init__.py
@@ -19,7 +19,8 @@ This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
-See the @{$python/contrib.learn} guide.
+See the [Contrib Learn](https://tensorflow.org/api_guides/python/contrib.learn)
+guide.
@@BaseEstimator
@@Estimator
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index a262a099cf..cbe4c03e4d 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Linear algebra libraries.
-See the @{$python/contrib.linalg} guide.
+See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
+guide.
@@LinearOperator
@@LinearOperatorBlockDiag
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 1e6f1e7da2..0091587bf7 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -154,6 +154,14 @@ cc_library(
"optional_debug_tools.h",
],
copts = tflite_copts(),
+ linkopts = [
+ ] + select({
+ "//tensorflow:android": [
+ "-llog",
+ ],
+ "//conditions:default": [
+ ],
+ }),
deps = [
":arena_planner",
":builtin_op_data",
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index ab694d768f..05d0b453ab 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -228,6 +228,7 @@ def generated_test_models():
"control_dep",
"conv",
"conv_with_shared_weights",
+ "conv_to_depthwiseconv_with_shared_weights",
"depthwiseconv",
"div",
"equal",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 8a8eb98568..e0e411e7a1 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -113,6 +113,7 @@ typedef enum {
kTfLiteBuiltinOneHot = 85,
kTfLiteBuiltinLogicalAnd = 86,
kTfLiteBuiltinLogicalNot = 87,
+ kTfLiteBuiltinUnpack = 88,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index c920f6a508..c7f4df3cdc 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -29,9 +29,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#if defined(_MSC_VER)
-#include <complex.h>
-#endif
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
@@ -49,7 +46,8 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
typedef enum {
kTfLiteEigenContext = 0, // include eigen_support.h to use.
kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
- kTfLiteMaxExternalContexts = 2
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
} TfLiteExternalContextType;
// An external context is a collection of information unrelated to the TF Lite
@@ -152,6 +150,11 @@ void TfLiteIntArrayFree(TfLiteIntArray* v);
} \
} while (0)
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
// Types supported by tensor
typedef enum {
kTfLiteNoType = 0,
@@ -183,11 +186,7 @@ typedef union {
uint8_t* uint8;
bool* b;
int16_t* i16;
-#if defined(_MSC_VER)
- _Fcomplex* c64;
-#else
- _Complex float* c64;
-#endif
+ TfLiteComplex64* c64;
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 5a7eb370f6..8abc828578 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,11 +16,10 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:framework",
@@ -55,12 +54,11 @@ cc_library(
":delegate_data",
":kernel",
":util",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -89,7 +87,7 @@ cc_library(
"//tensorflow/core/common_runtime/eager:context",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
@@ -119,7 +117,6 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -127,6 +124,9 @@ cc_library(
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
] + select({
+ # TODO(b/111881878): The android_tensorflow_lib target pulls in the full
+ # set of core TensorFlow kernels. We may want to revisit this dependency
+ # to allow selective registration via build targets.
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
@@ -167,13 +167,11 @@ cc_library(
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":constants",
"//tensorflow/c:c_api_internal",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -193,8 +191,3 @@ tf_cc_test(
"@com_google_googletest//:gtest",
],
)
-
-cc_library(
- name = "constants",
- hdrs = ["constants.h"],
-)
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
index 8ab768575e..45fc158157 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -83,27 +83,26 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace eager
-EagerDelegate::EagerDelegate() {}
-
-EagerDelegate::~EagerDelegate() {}
-
-TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) {
- if (!delegate_) {
- if (!eager::DelegateData::Create(&delegate_data_).ok()) {
- fprintf(stderr, "Unable to initialize TensorFlow context.\n");
- return kTfLiteError;
- }
-
- delegate_.reset(new TfLiteDelegate{
- /*data_=*/delegate_data_.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
- /*CopyToBufferHandle=*/nullptr,
- /*FreeBufferHandle=*/nullptr});
+std::unique_ptr<EagerDelegate> EagerDelegate::Create() {
+ std::unique_ptr<eager::DelegateData> delegate_data;
+ if (!eager::DelegateData::Create(&delegate_data).ok()) {
+ fprintf(stderr, "Unable to initialize TensorFlow context.\n");
+ return nullptr;
}
- return interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
+ return std::unique_ptr<EagerDelegate>(
+ new EagerDelegate(std::move(delegate_data)));
}
+EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data)
+ : TfLiteDelegate{
+ /*data_=*/delegate_data.get(),
+ /*nullptr,*/ &eager::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*CopyToBufferHandle=*/nullptr,
+ /*FreeBufferHandle=*/nullptr},
+ delegate_data_(std::move(delegate_data)) {}
+
+EagerDelegate::~EagerDelegate() {}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index a07002f487..6d15ba47dc 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
@@ -30,24 +29,29 @@ namespace tflite {
// interpreters, but it is *not* thread-safe.
//
// Usage:
-// EagerDelegate delegate;
+// auto delegate = EagerDelegate::Create();
// ... build interpreter ...
//
-// delegate.Apply(interpreter);
+// if (delegate) {
+// interpreter->ModifyGraphWithDelegate(
+// delegate.get(), /*allow_dynamic_tensors=*/true);
+// }
// ... run inference ...
// ... destroy interpreter ...
// ... destroy delegate ...
-class EagerDelegate {
+class EagerDelegate : public TfLiteDelegate {
public:
- EagerDelegate();
- ~EagerDelegate();
+ // Creates a delegate that supports TF ops.
+ //
+ // If the underyling TF Eager context creation fails, returns null.
+ static std::unique_ptr<EagerDelegate> Create();
- // Modifies the graph loaded in the interpreter.
- TfLiteStatus Apply(Interpreter* interpreter);
+ ~EagerDelegate();
private:
+ explicit EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data);
+
std::unique_ptr<eager::DelegateData> delegate_data_;
- std::unique_ptr<TfLiteDelegate> delegate_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 511a239363..eb47f46c0b 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -28,21 +28,21 @@ using ::testing::ElementsAre;
class DelegateTest : public testing::EagerModelTest {
public:
DelegateTest() {
- // The delegate needs to be constructed before the interpreter because the
- // interpreter references data contained in the delegate.
- delegate_.reset(new EagerDelegate());
+ delegate_ = EagerDelegate::Create();
interpreter_.reset(new Interpreter(&error_reporter_));
}
~DelegateTest() override {
// The delegate needs to be destructed after the interpreter because the
// interpreter references data contained in the delegate.
- delete interpreter_.release();
- delete delegate_.release();
+ interpreter_.reset();
+ delegate_.reset();
}
void ConfigureDelegate() {
- CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk);
+ ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(
+ delegate_.get(), /*allow_dynamic_tensors=*/true),
+ kTfLiteOk);
}
private:
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index c8aa0b7f69..4426c653e6 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -13,16 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/util.h"
-#include "tensorflow/contrib/lite/delegates/eager/constants.h"
namespace tflite {
namespace eager {
-bool IsEagerOp(const char* custom_name) {
- return custom_name && strncmp(custom_name, kCustomCodePrefix,
- strlen(kCustomCodePrefix)) == 0;
-}
-
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
if (!status.ok()) {
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index b7363361be..a9407be071 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -23,10 +23,6 @@ limitations under the License.
namespace tflite {
namespace eager {
-// Checks whether the prefix of the custom name indicates the operation is an
-// Eager operation.
-bool IsEagerOp(const char* custom_name);
-
// Converts a tensorflow:Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
TfLiteStatus ConvertStatus(TfLiteContext* context,
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 541d0b1701..53378a1eaf 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -103,16 +103,6 @@ TEST(UtilTest, TypeConversions) {
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
-TEST(UtilTest, IsEagerOp) {
- EXPECT_TRUE(IsEagerOp("Eager"));
- EXPECT_TRUE(IsEagerOp("EagerOp"));
- EXPECT_FALSE(IsEagerOp("eager"));
- EXPECT_FALSE(IsEagerOp("Eage"));
- EXPECT_FALSE(IsEagerOp("OpEager"));
- EXPECT_FALSE(IsEagerOp(nullptr));
- EXPECT_FALSE(IsEagerOp(""));
-}
-
} // namespace
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle
index a47fa4bbf6..66a62a921a 100644
--- a/tensorflow/contrib/lite/examples/android/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/build.gradle
@@ -14,6 +14,7 @@ buildscript {
allprojects {
repositories {
+ google()
jcenter()
}
}
diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
index 98934ce41d..96d2810937 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
+++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
-#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
#include <vector>
std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width,
int* out_height, int* out_channels);
-#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 98abd5743b..1dffe30790 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -1,6 +1,7 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
+- include: /versions/_upper_tabs_versions.yaml
# Dropdown menu
- name: Ecosystem
path: /ecosystem
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index a27df4b964..7d69aa2ad3 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -413,7 +413,12 @@ class Interpreter {
return op_reg.profiling_string(&context_, node);
}
+ // Set the value of an external context.
+ void SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+
private:
+ friend class InterpreterBuilder;
friend class InterpreterTest;
// Prevent 'context_' from accessing functions that are only available to
@@ -543,12 +548,30 @@ class Interpreter {
struct TfLiteContext* context, TfLiteExternalContextType type);
// Set the value of an external context.
- void SetExternalContext(TfLiteExternalContextType type,
- TfLiteExternalContext* ctx);
static void SetExternalContext(struct TfLiteContext* context,
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
+ using TfLiteDelegatePtr =
+ std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
+
+ // Variant of the public ModifyGraphWithDelegate method that additionally
+ // Assumes ownership of the provided delegate.
+ // WARNING: This is an experimental API and subject to change.
+ template <typename Delegate>
+ TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
+ bool allow_dynamic_tensors = false) {
+ TfLiteDelegatePtr delegate(typed_delegate.release(),
+ [](TfLiteDelegate* delegate) {
+ delete static_cast<Delegate*>(delegate);
+ });
+ // Note that we retain ownership of the delegate even if graph modification
+ // fails, as delegate use will be in an indeterminate state at that point.
+ owned_delegates_.push_back(std::move(delegate));
+ return ModifyGraphWithDelegate(owned_delegates_.back().get(),
+ allow_dynamic_tensors);
+ }
+
// Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra
// capacity. Calling this function may invalidate existing pointers to
// tensors. After calling this function, adding `kTensorsCapacityHeadroom`
@@ -628,6 +651,11 @@ class Interpreter {
// Whether to delegate to NN API
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
+ // List of delegates that have been installed and are owned by this
+ // interpreter instance. Useful if client delegate ownership is burdensome.
+ // WARNING: This is an experimental API and subject to change.
+ std::vector<TfLiteDelegatePtr> owned_delegates_;
+
std::unique_ptr<MemoryPlanner> memory_planner_;
bool allow_buffer_handle_output_ = false;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index f00697826c..5bcf0927d8 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -26,6 +26,13 @@ namespace tflite {
// InterpreterTest is a friend of Interpreter, so it can access context_.
class InterpreterTest : public ::testing::Test {
+ public:
+ template <typename Delegate>
+ static TfLiteStatus ModifyGraphWithDelegate(
+ Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
+ return interpreter->ModifyGraphWithDelegate(std::move(delegate));
+ }
+
protected:
TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; }
@@ -1302,6 +1309,57 @@ TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
ASSERT_EQ(interpreter_->execution_plan()[0], 1);
}
+TEST(TestDelegateOwnership, ProperlyDisposed) {
+ struct TfLiteInterpreterOwnedDelegate : public TfLiteDelegate {
+ TfLiteInterpreterOwnedDelegate(bool* destroyed, bool* prepared)
+ : destroyed(destroyed), prepared(prepared) {
+ Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus {
+ *static_cast<TfLiteInterpreterOwnedDelegate*>(delegate)->prepared =
+ true;
+ return kTfLiteOk;
+ };
+ }
+ ~TfLiteInterpreterOwnedDelegate() { *destroyed = true; }
+
+ bool* destroyed;
+ bool* prepared;
+ };
+
+ // Construct a delegate with flags for indicating preparation/destruction.
+ bool destroyed = false;
+ bool prepared = false;
+ std::unique_ptr<TfLiteInterpreterOwnedDelegate> delegate(
+ new TfLiteInterpreterOwnedDelegate(&destroyed, &prepared));
+ {
+ // Create an interpreter and assemble a simple graph.
+ Interpreter interpreter;
+ TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk);
+ ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+ &registration),
+ kTfLiteOk);
+
+ // Pass delegate ownership to that interpreter.
+ ASSERT_EQ(InterpreterTest::ModifyGraphWithDelegate(&interpreter,
+ std::move(delegate)),
+ kTfLiteOk);
+
+ // The delegate should be prepared as normal, and should be preserved.
+ EXPECT_TRUE(prepared);
+ EXPECT_FALSE(destroyed);
+
+ // Interpreter interaction should not impact the delegate's validity.
+ interpreter.AllocateTensors();
+ interpreter.Invoke();
+ EXPECT_FALSE(destroyed);
+ }
+
+ // Only after the interpreter is destroyed should the delegate be destroyed.
+ EXPECT_TRUE(destroyed);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 94a1ec65d6..41093e8ffe 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -15,8 +15,8 @@ limitations under the License.
package org.tensorflow.lite;
-/** Type of elements in a {@link TfLiteTensor}. */
-enum DataType {
+/** Represents the type of elements in a TensorFlow Lite {@link Tensor} as an enum. */
+public enum DataType {
/** 32-bit single precision floating point. */
FLOAT32(1),
@@ -35,13 +35,29 @@ enum DataType {
this.value = value;
}
- /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
- int getNumber() {
+ /** Returns the size of an element of this type, in bytes, or -1 if element size is variable. */
+ public int byteSize() {
+ switch (this) {
+ case FLOAT32:
+ return 4;
+ case INT32:
+ return 4;
+ case UINT8:
+ return 1;
+ case INT64:
+ return 8;
+ }
+ throw new IllegalArgumentException(
+ "DataType error: DataType " + this + " is not supported yet");
+ }
+
+ /** Corresponding value of the TfLiteType enum in the TensorFlow Lite C API. */
+ int c() {
return value;
}
- /** Converts an integer to the corresponding type. */
- static DataType fromNumber(int c) {
+ /** Converts a C TfLiteType enum value to the corresponding type. */
+ static DataType fromC(int c) {
for (DataType t : values) {
if (t.value == c) {
return t;
@@ -55,22 +71,6 @@ enum DataType {
+ ")");
}
- /** Returns byte size of the type. */
- int elemByteSize() {
- switch (this) {
- case FLOAT32:
- return 4;
- case INT32:
- return 4;
- case UINT8:
- return 1;
- case INT64:
- return 8;
- }
- throw new IllegalArgumentException(
- "DataType error: DataType " + this + " is not supported yet");
- }
-
/** Gets string names of the data type. */
String toStringName() {
switch (this) {
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 7002f82677..b84720ae8e 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -162,9 +162,7 @@ public final class Interpreter implements AutoCloseable {
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.run(inputs, outputs);
}
@@ -174,12 +172,16 @@ public final class Interpreter implements AutoCloseable {
* <p>IllegalArgumentException will be thrown if it fails to resize.
*/
public void resizeInput(int idx, @NonNull int[] dims) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.resizeInput(idx, dims);
}
+ /** Gets the number of input tensors. */
+ public int getInputTensorCount() {
+ checkNotClosed();
+ return wrapper.getInputTensorCount();
+ }
+
/**
* Gets index of an input given the op name of the input.
*
@@ -187,51 +189,65 @@ public final class Interpreter implements AutoCloseable {
* to initialize the {@link Interpreter}.
*/
public int getInputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getInputIndex(opName);
}
/**
+ * Gets the Tensor associated with the provdied input index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getInputTensor(int inputIndex) {
+ checkNotClosed();
+ return wrapper.getInputTensor(inputIndex);
+ }
+
+ /** Gets the number of output Tensors. */
+ public int getOutputTensorCount() {
+ checkNotClosed();
+ return wrapper.getOutputTensorCount();
+ }
+
+ /**
* Gets index of an output given the op name of the output.
*
* <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
* to initialize the {@link Interpreter}.
*/
public int getOutputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getOutputIndex(opName);
}
/**
+ * Gets the Tensor associated with the provdied output index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getOutputTensor(int outputIndex) {
+ checkNotClosed();
+ return wrapper.getOutputTensor(outputIndex);
+ }
+
+ /**
* Returns native inference timing.
* <p>IllegalArgumentException will be thrown if the model is not initialized by the
* {@link Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
/** Turns on/off Android NNAPI for hardware acceleration when it is available. */
public void setUseNNAPI(boolean useNNAPI) {
- if (wrapper != null) {
- wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalStateException(
- "Internal error: NativeInterpreterWrapper has already been closed.");
- }
+ checkNotClosed();
+ wrapper.setUseNNAPI(useNNAPI);
}
public void setNumThreads(int numThreads) {
- if (wrapper == null) {
- throw new IllegalStateException("The interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.setNumThreads(numThreads);
}
@@ -253,5 +269,11 @@ public final class Interpreter implements AutoCloseable {
}
}
+ private void checkNotClosed() {
+ if (wrapper == null) {
+ throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
+ }
+ }
+
NativeInterpreterWrapper wrapper;
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 767a220f8c..fa25082304 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -114,12 +114,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- if (!isMemoryAllocated) {
+ boolean needsAllocation = !isMemoryAllocated;
+ if (needsAllocation) {
allocateTensors(interpreterHandle, errorHandle);
isMemoryAllocated = true;
- // Allocation can trigger dynamic resizing of output tensors, so clear the
- // output tensor cache.
- Arrays.fill(outputTensors, null);
}
for (int i = 0; i < inputs.length; ++i) {
@@ -130,6 +128,14 @@ final class NativeInterpreterWrapper implements AutoCloseable {
run(interpreterHandle, errorHandle);
long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+ // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
+ if (needsAllocation) {
+ for (int i = 0; i < outputTensors.length; ++i) {
+ if (outputTensors[i] != null) {
+ outputTensors[i].refreshShape();
+ }
+ }
+ }
for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
getOutputTensor(output.getKey()).copyTo(output.getValue());
}
@@ -144,8 +150,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
- // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
- inputTensors[idx] = null;
+ if (inputTensors[idx] != null) {
+ inputTensors[idx].refreshShape();
+ }
}
}
@@ -230,6 +237,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return getOutputQuantizationScale(interpreterHandle, index);
}
+ /** Gets the number of input tensors. */
+ int getInputTensorCount() {
+ return inputTensors.length;
+ }
+
/**
* Gets the input {@link Tensor} for the provided input index.
*
@@ -247,6 +259,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return inputTensor;
}
+ /** Gets the number of output tensors. */
+ int getOutputTensorCount() {
+ return inputTensors.length;
+ }
+
/**
* Gets the output {@link Tensor} for the provided output index.
*
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 2403570c52..f174178d98 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
* <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
* needed to be closed here.
*/
-final class Tensor {
+public final class Tensor {
static Tensor fromHandle(long nativeHandle) {
return new Tensor(nativeHandle);
@@ -37,11 +37,26 @@ final class Tensor {
return dtype;
}
+ /**
+ * Returns the number of dimensions (sometimes referred to as <a
+ * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
+ *
+ * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
+ */
+ public int numDimensions() {
+ return shapeCopy.length;
+ }
+
/** Returns the size, in bytes, of the tensor data. */
public int numBytes() {
return numBytes(nativeHandle);
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor. */
+ public int numElements() {
+ return computeNumElements(shapeCopy);
+ }
+
/**
* Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
* the Tensor, i.e., the sizes of each dimension.
@@ -103,13 +118,22 @@ final class Tensor {
if (isByteBuffer(input)) {
return null;
}
- int[] inputShape = shapeOf(input);
+ int[] inputShape = computeShapeOf(input);
if (Arrays.equals(shapeCopy, inputShape)) {
return null;
}
return inputShape;
}
+ /**
+ * Forces a refresh of the tensor's cached shape.
+ *
+ * <p>This is useful if the tensor is resized or has a dynamic shape.
+ */
+ void refreshShape() {
+ this.shapeCopy = shape(nativeHandle);
+ }
+
/** Returns the type of the data. */
static DataType dataTypeOf(Object o) {
if (o != null) {
@@ -132,22 +156,31 @@ final class Tensor {
}
/** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
+ static int[] computeShapeOf(Object o) {
+ int size = computeNumDimensions(o);
int[] dimensions = new int[size];
fillShape(o, 0, dimensions);
return dimensions;
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */
+ static int computeNumElements(int[] shape) {
+ int n = 1;
+ for (int i = 0; i < shape.length; ++i) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
/** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
- static int numDimensions(Object o) {
+ static int computeNumDimensions(Object o) {
if (o == null || !o.getClass().isArray()) {
return 0;
}
if (Array.getLength(o) == 0) {
throw new IllegalArgumentException("Array lengths cannot be 0.");
}
- return 1 + numDimensions(Array.get(o, 0));
+ return 1 + computeNumDimensions(Array.get(o, 0));
}
/** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
@@ -188,7 +221,7 @@ final class Tensor {
dtype, o.getClass().getName(), oType));
}
- int[] oShape = shapeOf(o);
+ int[] oShape = computeShapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
@@ -204,11 +237,11 @@ final class Tensor {
private final long nativeHandle;
private final DataType dtype;
- private final int[] shapeCopy;
+ private int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
- this.dtype = DataType.fromNumber(dtype(nativeHandle));
+ this.dtype = DataType.fromC(dtype(nativeHandle));
this.shapeCopy = shape(nativeHandle);
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
index cebc944200..6d6417f895 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -26,9 +26,16 @@ public final class DataTypeTest {
@Test
public void testElemByteSize() {
- assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
- assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ assertThat(DataType.FLOAT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.byteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.byteSize()).isEqualTo(8);
+ }
+
+ @Test
+ public void testConversion() {
+ for (DataType dataType : DataType.values()) {
+ assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
+ }
}
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index d66a73db94..9070b788b6 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -47,6 +47,10 @@ public final class InterpreterTest {
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
interpreter.close();
}
@@ -183,6 +187,19 @@ public final class InterpreterTest {
}
@Test
+ public void testResizeInput() {
+ try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
+ int[] inputDims = {1};
+ interpreter.resizeInput(0, inputDims);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
+ ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ interpreter.run(input, output);
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
+ }
+ }
+
+ @Test
public void testMobilenetRun() {
// Create a gray image.
float[][][][] img = new float[1][224][224][3];
@@ -199,6 +216,8 @@ public final class InterpreterTest {
Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
interpreter.run(img, labels);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
interpreter.close();
assertThat(labels[0])
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 71ef044943..85ad393d89 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -64,6 +64,8 @@ public final class TensorTest {
assertThat(tensor.shape()).isEqualTo(expectedShape);
assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
+ assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3);
+ assertThat(tensor.numDimensions()).isEqualTo(4);
}
@Test
@@ -201,12 +203,12 @@ public final class TensorTest {
@Test
public void testNumDimensions() {
int scalar = 1;
- assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0);
int[][] array = {{2, 4}, {1, 9}};
- assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2);
try {
int[] emptyArray = {};
- Tensor.numDimensions(emptyArray);
+ Tensor.computeNumDimensions(emptyArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
@@ -214,9 +216,21 @@ public final class TensorTest {
}
@Test
+ public void testNumElements() {
+ int[] scalarShape = {};
+ assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1);
+ int[] vectorShape = {3};
+ assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3);
+ int[] matrixShape = {3, 4};
+ assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12);
+ int[] degenerateShape = {3, 4, 0};
+ assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0);
+ }
+
+ @Test
public void testFillShape() {
int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = Tensor.numDimensions(array);
+ int num = Tensor.computeNumDimensions(array);
int[] shape = new int[num];
Tensor.fillShape(array, 0, shape);
assertThat(num).isEqualTo(3);
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index c5586475ec..1f528fdab9 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -225,6 +225,7 @@ cc_library(
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite:util",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 7f0676be27..df4d871466 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -46,8 +46,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -580,8 +580,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -601,8 +601,8 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2d172315da..f19df5e17e 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -2327,8 +2327,8 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
const auto input = MapAsVector(input_data, input_shape);
@@ -2946,7 +2946,58 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
inline void MulElementwise(int size, const ArithmeticParams& params,
const uint8* input1_data, const uint8* input2_data,
uint8* output_data) {
- for (int i = 0; i < size; ++i) {
+ int i = 0;
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ TFLITE_DCHECK_GT(params.output_offset, -256);
+ TFLITE_DCHECK_LT(params.output_offset, 256);
+#ifdef USE_NEON
+ const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
+ const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
+ const auto output_offset_vector = vdupq_n_s16(params.output_offset);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
+ for (; i <= size - 8; i += 8) {
+ // We load / store 8 at a time, multiplying as two sets of 4 int32s.
+ const auto input1_val_original = vld1_u8(input1_data + i);
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input1_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
+ const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
+
+ const auto input1_val_low = vget_low_s16(input1_val);
+ const auto input1_val_high = vget_high_s16(input1_val);
+ const auto input2_val_low = vget_low_s16(input2_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+
+ auto p1 = vmull_s16(input2_val_low, input1_val_low);
+ auto p2 = vmull_s16(input2_val_high, input1_val_high);
+
+ p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
+ p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ p1 = RoundingDivideByPOT(p1, -params.output_shift);
+ p2 = RoundingDivideByPOT(p2, -params.output_shift);
+
+ const auto p1_narrowed = vmovn_s32(p1);
+ const auto p2_narrowed = vmovn_s32(p2);
+ const auto p =
+ vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
+ const auto clamped =
+ vmax_u8(output_activation_min_vector,
+ vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
+ vst1_u8(output_data + i, clamped);
+ }
+#endif // NEON
+
+ for (; i < size; ++i) {
const int32 input1_val = params.input1_offset + input1_data[i];
const int32 input2_val = params.input2_offset + input2_data[i];
const int32 unclamped_result =
@@ -2965,9 +3016,53 @@ inline void MulElementwise(int size, const ArithmeticParams& params,
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
const uint8 broadcast_value,
const uint8* input2_data, uint8* output_data) {
- const int32 input1_val = params.input1_offset + broadcast_value;
+ const int16 input1_val = params.input1_offset + broadcast_value;
+
+ int i = 0;
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ TFLITE_DCHECK_GT(params.output_offset, -256);
+ TFLITE_DCHECK_LT(params.output_offset, 256);
+#ifdef USE_NEON
+ const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
+ const auto output_offset_vector = vdupq_n_s16(params.output_offset);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
+ for (; i <= size - 8; i += 8) {
+ // We load / store 8 at a time, multiplying as two sets of 4 int32s.
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
+
+ const auto input2_val_low = vget_low_s16(input2_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+
+ auto p1 = vmull_n_s16(input2_val_low, input1_val);
+ auto p2 = vmull_n_s16(input2_val_high, input1_val);
+
+ p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
+ p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ p1 = RoundingDivideByPOT(p1, -params.output_shift);
+ p2 = RoundingDivideByPOT(p2, -params.output_shift);
- for (int i = 0; i < size; ++i) {
+ const auto p1_narrowed = vmovn_s32(p1);
+ const auto p2_narrowed = vmovn_s32(p2);
+ const auto p =
+ vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
+ const auto clamped =
+ vmax_u8(output_activation_min_vector,
+ vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
+ vst1_u8(output_data + i, clamped);
+ }
+#endif // NEON
+
+ for (; i < size; ++i) {
const int32 input2_val = params.input2_offset + input2_data[i];
const int32 unclamped_result =
params.output_offset +
@@ -4449,8 +4544,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -4595,8 +4690,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -4655,8 +4750,14 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+// Legacy version.
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -4919,14 +5020,21 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
output_map.array() = input_map.array().template cast<DstT>();
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = Eigen::floor(input_map.array());
}
+// Legacy Dims<4> version.
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
float scale, float* output_ptr) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index b862ae38c7..71ae74f34c 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -42,20 +42,20 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Relu1(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu1(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Relu6(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu6(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
template <FusedActivationFunctionType Ac>
@@ -583,8 +583,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -598,14 +598,14 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
int16* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index cb254f36cc..556049d8a6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -846,8 +846,8 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
@@ -857,8 +857,8 @@ inline void Relu(const float* input_data, const RuntimeShape& input_shape,
}
}
-inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu1(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -870,8 +870,8 @@ inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
}
}
-inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -1494,6 +1494,7 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
+ // The input shapes are extended as part of NdArrayDesc initialization.
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
RuntimeShape extended_output_shape =
@@ -2075,6 +2076,44 @@ inline void Concatenation(int concat_dim, const uint8* const* input_data,
}
}
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, const int32* input_zeropoint,
+ const float* input_scale, int inputs_count, Scalar* output_data,
+ const Dims<4>& output_dims, const int32 output_zeropoint,
+ const float output_scale) {
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ int outer_size = 1;
+ for (int i = dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ const int copy_size = FlatSize(**input_dims) / outer_size;
+ const float inverse_output_scale = 1.f / output_scale;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ if (input_zeropoint[i] == output_zeropoint &&
+ input_scale[i] == output_scale) {
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ } else {
+ assert(false);
+ const float scale = input_scale[i] * inverse_output_scale;
+ const float bias = -input_zeropoint[i] * scale;
+ auto input_ptr = input_data[i];
+ for (int j = 0; j < copy_size; ++j) {
+ const int32_t value =
+ static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
+ output_zeropoint;
+ output_ptr[j] =
+ static_cast<uint8_t>(std::max(std::min(255, value), 0));
+ }
+ }
+ output_ptr += copy_size;
+ }
+ }
+}
+
template <FusedActivationFunctionType Ac, typename Scalar>
void DepthConcatenation(const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
@@ -3117,8 +3156,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3166,8 +3205,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3184,8 +3223,8 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -4069,21 +4108,24 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
+ const T* input1_data,
+ const RuntimeShape& input2_shape,
+ const T* input2_data,
+ const RuntimeShape& output_shape,
+ T* output_data, Op op) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- auto out_idx = Offset(output_dims, c, x, y, b);
- auto in1_idx = SubscriptToIndex(desc1, c, x, y, b);
- auto in2_idx = SubscriptToIndex(desc2, c, x, y, b);
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = op(in1_val, in2_val);
@@ -4093,9 +4135,20 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
}
+template <typename T, typename Op>
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
+ MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, op);
+}
+
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
+ const T1* input_data, const RuntimeShape& output_shape,
+ T2* output_data, const Cmp& cmp) {
// The current ArgMax implemention can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4103,9 +4156,11 @@ void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = ArraySize(input_dims, 0);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.Dims(3), 1);
+ const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape);
+ const int depth = input_shape.Dims(3);
for (int i = 0; i < outer_size; ++i) {
auto min_max_value = input_data[i * depth];
@@ -4121,6 +4176,15 @@ void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Dims<4> version.
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+ ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data, cmp);
+}
+
+// Legacy.
// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
@@ -4253,16 +4317,26 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
+inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
+// Legacy Dims<4> version.
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ Comparison<T, F>(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T, ComparisonFn<int32> F>
inline void Comparison(int left_shift, const T* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
@@ -4473,69 +4547,156 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
}
template <typename T>
-inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = std::pow(input1_data[i], input2_data[i]);
}
}
+// Legacy Dims<4> version.
template <typename T>
-inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
+ const T* input1_data,
+ const RuntimeShape& input2_shape,
+ const T* input2_data,
+ const RuntimeShape& output_shape,
+ T* output_data) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = std::pow(in1_val, in2_val);
}
}
}
}
}
+// Legacy Dims<4> version.
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
+ const RuntimeShape& input2_shape, const bool* input2_data,
+ const RuntimeShape& output_shape, bool* output_data,
+ const std::function<bool(bool, bool)>& func) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
+
+// Legacy Dims<4> version.
inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
const bool* input2_data, const Dims<4>& input2_dims,
bool* output_data, const Dims<4>& output_dims,
const std::function<bool(bool, bool)>& func) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = func(input1_data[i], input2_data[i]);
+ Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data, func);
+}
+
+inline void BroadcastLogical4DSlow(
+ const RuntimeShape& input1_shape, const bool* input1_data,
+ const RuntimeShape& input2_shape, const bool* input2_data,
+ const RuntimeShape& output_shape, bool* output_data,
+ const std::function<bool(bool, bool)>& func) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
+ }
+ }
+ }
}
}
+// Legacy Dims<4> version.
inline void BroadcastLogical(const bool* input1_data,
const Dims<4>& input1_dims,
const bool* input2_data,
const Dims<4>& input2_dims, bool* output_data,
const Dims<4>& output_dims,
const std::function<bool(bool, bool)>& func) {
+ BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
+//
+// Also appears to duplicte MinimumMaximum.
+//
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape,
+ R* output_data, R (*func)(T1, T2)) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
}
}
}
}
}
-// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
-// generalized and efficient BroadcastBinaryFunction.
+// Legacy Dims<4> version.
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
@@ -4545,20 +4706,9 @@ inline void BroadcastBinaryFunction(const T1* input1_data,
const Dims<4>& input2_dims, R* output_data,
const Dims<4>& output_dims,
R (*func)(T1, T2)) {
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
- }
- }
- }
- }
+ BroadcastBinaryFunction4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
}
} // namespace reference_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 7b6838db53..204df9ab19 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -660,6 +660,19 @@ enum class BroadcastableOpCategory : uint8 {
kGenericBroadcast, // Fall-back.
};
+struct MinMax {
+ float min;
+ float max;
+};
+static_assert(sizeof(MinMax) == 8, "");
+
+struct ActivationParams {
+ FusedActivationFunctionType activation_type;
+ // Quantized inference params.
+ int32 activation_min;
+ int32 activation_max;
+};
+
// For Add, Sub, Mul ops.
struct ArithmeticParams {
// Shape dependent / common to data / op types.
@@ -695,29 +708,122 @@ struct ArithmeticParams {
int broadcast_shape[5];
};
-template <typename T>
-inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+struct ConcatenationParams {
+ int8 axis;
+};
-template <>
-inline void SetActivationParams(float min, float max,
- ArithmeticParams* params) {
- params->float_activation_min = min;
- params->float_activation_max = max;
-}
+struct ComparisonParams {
+ // uint8 inference params.
+ int left_shift;
+ int32 input0_offset;
+ int32 input0_multiplier;
+ int input0_shift;
+ int32 input1_offset;
+ int32 input1_multiplier;
+ int input1_shift;
+ // Shape dependent / common to inference types.
+ bool is_broadcast;
+};
-template <>
-inline void SetActivationParams(int32 min, int32 max,
- ArithmeticParams* params) {
- params->quantized_activation_min = min;
- params->quantized_activation_max = max;
-}
+struct ConvParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ // TODO(starka): This was just "stride", so check that width+height is OK.
+ int8 stride_width;
+ int8 stride_height;
+ int8 dilation_width_factor;
+ int8 dilation_height_factor;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+};
+
+struct DepthToSpaceParams {
+ int16 block_size;
+};
+
+struct DepthwiseParams {
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int8 stride;
+ int8 depth_multiplier;
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+};
+
+struct FakeQuantParams {
+ MinMax minmax;
+ int32 num_bits;
+};
+
+struct FullyConnectedParams {
+ // uint8 inference params.
+ // TODO(b/65838351): Use smaller types if appropriate.
+ int32 input_offset;
+ int32 weights_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+ FullyConnectedWeightsFormat weights_format;
+};
+
+struct GatherParams {
+ int8 input_rank;
+ int16 axis;
+};
+
+struct L2NormalizationParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+};
+
+struct LocalResponseNormalizationParams {
+ int32 range;
+ double bias;
+ double alpha;
+ double beta;
+};
+
+struct LogisticParams {
+ // uint8 inference params.
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+struct LstmCellParams {
+ int32 weights_zero_point;
+ int32 accum_multiplier;
+ int accum_shift;
+ int state_integer_bits;
+};
+
+struct MeanParams {
+ int8 axis_count;
+ int16 axis[4];
+};
struct PadParams {
int8 left_padding_count;
int32 left_padding[4];
int8 right_padding_count;
int32 right_padding[4];
- // FloatOrInt pad_value;
};
struct PoolParams {
@@ -736,6 +842,15 @@ struct PoolParams {
float float_activation_max;
};
+struct ReshapeParams {
+ int8 shape_count;
+ int32 shape[4];
+};
+
+struct ResizeBilinearParams {
+ bool align_corners;
+};
+
struct SliceParams {
int8 begin_count;
int32 begin[4];
@@ -743,6 +858,73 @@ struct SliceParams {
int32 size[4];
};
+struct SoftmaxParams {
+ // beta is not really used (not a Tensorflow parameter) and not implemented
+ // for LogSoftmax.
+ double beta;
+ // uint8 inference params. Used even when beta defaults to 1.0.
+ int32 input_beta_multiplier;
+ int32 input_beta_left_shift;
+ // Reverse scaling is only used by LogSoftmax.
+ int32 reverse_scaling_divisor;
+ int32 reverse_scaling_right_shift;
+ int diff_min;
+};
+
+struct SpaceToDepthParams {
+ int16 block_size;
+};
+
+struct SplitParams {
+ // Graphs that split into, say, 2000 nodes are encountered. The indices in
+ // OperatorEdges are of type uint16.
+ uint16 num_split;
+};
+
+struct SqueezeParams {
+ int8 squeeze_dims_count;
+ int32 squeeze_dims[4];
+};
+
+struct StridedSliceParams {
+ int8 start_indices_count;
+ int16 start_indices[4];
+ int8 stop_indices_count;
+ int16 stop_indices[4];
+ int8 strides_count;
+ int16 strides[4];
+
+ int16 begin_mask;
+ int16 ellipsis_mask;
+ int16 end_mask;
+ int16 new_axis_mask;
+ int16 shrink_axis_mask;
+};
+
+struct TanhParams {
+ int32 input_zero_point;
+ int32 input_range_radius;
+ int32 input_multiplier;
+ int input_left_shift;
+};
+
+template <typename T>
+inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+
+template <>
+inline void SetActivationParams(float min, float max,
+ ArithmeticParams* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <>
+inline void SetActivationParams(int32 min, int32 max,
+ ArithmeticParams* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index bb3416f6a6..cc326a7d51 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -27,24 +27,9 @@ namespace {
constexpr int kOutputTensor = 0;
-// Op data for pack op.
-struct OpData {
- int values_count;
- int axis;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* data = new OpData;
- data->axis = 0;
- return data;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<OpData*>(buffer);
-}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -54,9 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
// TODO(renjieliu): Support negative axis.
TF_LITE_ENSURE(context, data->axis >= 0);
- if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) {
+ if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
+ input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int16/int32.");
return kTfLiteError;
}
// Make sure all inputs have the same shape and type.
@@ -82,6 +69,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, output->type, input0->type);
+ // Guarantee input/output quantization params match as we do not support
+ // packing quantized tensors.
+ for (int i = 0; i < data->values_count; i++) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ }
+
return context->ResizeTensor(context, output, output_shape);
}
@@ -95,7 +91,8 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (output->type) {
@@ -103,13 +100,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PackImpl<float>(context, node, output, data->values_count, data->axis);
break;
}
+ case kTfLiteUInt8: {
+ PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
case kTfLiteInt32: {
PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
break;
}
default: {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int32.");
return kTfLiteError;
}
}
@@ -121,8 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace pack
TfLiteRegistration* Register_PACK() {
- static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare,
- pack::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc
index 485a50ad3a..c70dbd2764 100644
--- a/tensorflow/contrib/lite/kernels/pack_test.cc
+++ b/tensorflow/contrib/lite/kernels/pack_test.cc
@@ -51,6 +51,7 @@ class PackOpModel : public SingleOpModel {
int output_;
};
+// float32 tests.
TEST(PackOpTest, FloatThreeInputs) {
PackOpModel<float> model({TensorType_FLOAT32, {2}}, 0, 3);
model.SetInput(0, {1, 4});
@@ -81,7 +82,8 @@ TEST(PackOpTest, FloatMultilDimensions) {
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
-TEST(PackOpTest, IntThreeInputs) {
+// int32 tests.
+TEST(PackOpTest, Int32ThreeInputs) {
PackOpModel<int32_t> model({TensorType_INT32, {2}}, 0, 3);
model.SetInput(0, {1, 4});
model.SetInput(1, {2, 5});
@@ -91,7 +93,7 @@ TEST(PackOpTest, IntThreeInputs) {
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
}
-TEST(PackOpTest, IntThreeInputsDifferentAxis) {
+TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
PackOpModel<int32_t> model({TensorType_INT32, {2}}, 1, 3);
model.SetInput(0, {1, 4});
model.SetInput(1, {2, 5});
@@ -101,7 +103,7 @@ TEST(PackOpTest, IntThreeInputsDifferentAxis) {
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
-TEST(PackOpTest, IntMultilDimensions) {
+TEST(PackOpTest, Int32MultilDimensions) {
PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
model.SetInput(0, {1, 2, 3, 4, 5, 6});
model.SetInput(1, {7, 8, 9, 10, 11, 12});
@@ -110,6 +112,38 @@ TEST(PackOpTest, IntMultilDimensions) {
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
+
+// uint8
+TEST(PackOpTest, Uint8ThreeInputs) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, Uint8MultilDimensions) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 6159311910..9681b900b7 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
namespace ops {
@@ -129,9 +130,7 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
// Return the NULL Op for all ops whose name start with "Eager", allowing
// the interpreter to delegate their execution.
- // TODO(ycling): Refactoring and extract an `IsEagerOp` function into
- // `lite:framework` build target.
- if (string(op).find("Eager") == 0) {
+ if (IsEagerOp(op)) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 9edf5ba38f..5988b7a3a7 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -26,6 +26,9 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
+#if defined(TFLITE_EXTENDED)
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@@ -786,6 +789,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGICAL_OR:
case BuiltinOperator_LOGICAL_AND:
case BuiltinOperator_LOGICAL_NOT:
+ case BuiltinOperator_UNPACK:
break;
}
return kTfLiteOk;
@@ -1040,6 +1044,14 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
+#if defined(TFLITE_EXTENDED)
+ if (auto delegate = EagerDelegate::Create()) {
+ (**interpreter)
+ .ModifyGraphWithDelegate(std::move(delegate),
+ /*allow_dynamic_tensors=*/true);
+ }
+#endif
+
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 45c92a8671..5d8e7a50e2 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -647,6 +647,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_ONE_HOT:
case tflite::BuiltinOperator_LOGICAL_AND:
case tflite::BuiltinOperator_LOGICAL_NOT:
+ case tflite::BuiltinOperator_UNPACK:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 860aff9e7e..47f0c8e9a2 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -112,8 +112,11 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
"//tensorflow/python:platform",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 11d4bdbe82..12cc66dc55 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os as _os
+import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile
@@ -26,6 +27,7 @@ from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -90,12 +92,13 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
fp_output.name
]
cmdline = " ".join(cmd)
+ is_windows = _platform.system() == "Windows"
proc = _subprocess.Popen(
cmdline,
shell=True,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
- close_fds=True)
+ close_fds=not is_windows)
stdout, stderr = proc.communicate()
exitcode = proc.returncode
if exitcode == 0:
@@ -223,7 +226,8 @@ def build_toco_convert_protos(input_tensors,
return model, toco
-def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs):
""""Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
@@ -252,3 +256,30 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
toco_flags.SerializeToString(),
input_data.SerializeToString())
return data
+
+
+@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ Typically this function is used to convert from TensorFlow GraphDef to TFLite.
+ Conversion can be customized by providing arguments that are forwarded to
+ `build_toco_convert_protos` (see documentation for details).
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`),
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ *args: See `build_toco_convert_protos`,
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+ The converted data. For example if TFLite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ return toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs)
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index dc21a9b669..bc05514cec 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -113,12 +113,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["cool_activation", "Const", "Identity"])
def testScaleAndBiasAndIdentity(self):
@@ -139,12 +140,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
def testTwoFunctions(self):
@@ -153,7 +155,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
b = array_ops.constant([1.])
def _double_values(x):
custom = op_hint.OpHint("add_test")
- x = custom.add_inputs(x)
+ x, = custom.add_inputs(x)
output = math_ops.multiply(x, x)
output, = custom.add_outputs(output)
return output
@@ -164,13 +166,90 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["add_test", "Const", "Identity", "Add"])
+ def _get_input_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
+
+ def _get_output_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
+
+ def _get_sort_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
+
+ def testTags(self):
+ """Test if multiple args with the same tag are grouped."""
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ d = array_ops.constant([4.])
+ custom = op_hint.OpHint("test_tag")
+ a = custom.add_input(a, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b, = custom.add_inputs(b)
+ c = custom.add_input(c, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ d = custom.add_input(d, tag="mytag2",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
+ custom.add_outputs([res])
+ with self.test_session():
+ self.assertEqual(self._get_input_index(a), 0)
+ self.assertEqual(self._get_sort_index(a), 0)
+ self.assertEqual(self._get_input_index(b), 1)
+ self.assertEqual(self._get_input_index(c), 0)
+ self.assertEqual(self._get_sort_index(c), 1)
+
+ def testOverrideIndex(self):
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ custom = op_hint.OpHint("test_override")
+ b = custom.add_input(b) # should auto assign 0
+ a = custom.add_input(a, index_override=1)
+ c = custom.add_input(c) # should auto assign 2
+ with self.test_session():
+ self.assertEqual(self._get_input_index(a), 1)
+ self.assertEqual(self._get_input_index(b), 0)
+ self.assertEqual(self._get_input_index(c), 2)
+
+ def testAggregate(self):
+ a = array_ops.constant([3., 4.])
+ b = array_ops.constant([5., 6.])
+ hint = op_hint.OpHint("agg")
+ a0, a1 = array_ops.unstack(a)
+ b0, b1 = array_ops.unstack(b)
+
+ a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ c0 = math_ops.add(a0, b0, name="addleft")
+ c1 = math_ops.add(a1, b1, name="addright")
+ c0 = hint.add_output(
+ c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ c1 = hint.add_output(
+ c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ curr = array_ops.stack([c0, c1])
+ output = array_ops.identity(curr, name="FINAL_OUTPUT")
+ with self.test_session() as sess:
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
+ print(stubbed_graphdef)
+ self.assertCountEqual(
+ self._getGraphOpTypes(
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
+ ["agg", "Const", "Identity"])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 5ec52035ad..2313bfa3b6 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -41,7 +41,8 @@ from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
-from tensorflow.contrib.lite.python.convert import toco_convert
+from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
@@ -110,6 +111,7 @@ class TocoConverter(object):
Example usage:
+ ```python
# Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
@@ -124,6 +126,11 @@ class TocoConverter(object):
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
+
+ # Converting a tf.keras model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ tflite_model = converter.convert()
+ ```
"""
def __init__(self, graph_def, input_tensors, output_tensors):
@@ -354,7 +361,7 @@ class TocoConverter(object):
quantized_stats = None
# Converts model.
- result = toco_convert(
+ result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index 7908689ce4..8c920132e5 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -25,9 +25,9 @@ Example:
def tflite_cool_activation(input):
# A cool activation function.
custom = tf.contrib.lite.OpHint("cool_activation")
- input = custom.add_inputs(input)
+ input, = custom.add_inputs(input)
output = tf.sigmoid(input) * input
- custom.add_outputs(output)
+ output, = custom.add_outputs(output)
return output
image = tf.placeholder(tf.float32, (1, 16, 16, 1))
@@ -64,18 +64,27 @@ ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later.
"""
+# TODO(aselle): Make this use generic graph transformations.
+# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as _collections
-import itertools as _itertools
+import copy as _copy
import uuid as _uuid
+import six as _six
-from tensorflow.contrib import framework as _framework
from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import ops as _ops
+# TODO(aselle): publicize these apis if we continue to use these.
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
+from tensorflow.python.util import compat as _compat
from tensorflow.python.util.all_util import remove_undocumented
@@ -97,11 +106,174 @@ class OpHint(object):
constructs, this mechanism can be retired and changed to use python defun's.
"""
- # Attr constants that are used for representation in the GraphDef
+ # Attr constants that are used for representation in the GraphDef. These
+ # will be used on every Identity op that is involved in a total OpHint.
+
+ # Name of the OpHint function (cosmetic).
FUNCTION_NAME_ATTR = "_tflite_function_name"
+ # UUID of the function (each OpHint gets a new uuid).
FUNCTION_UUID_ATTR = "_tflite_function_uuid"
+ # The index index of the input (or nothing if it is an output).
FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
+ # The output index of the output (or nothing if it is an input).
FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
+ # An index that orders aggregate arguments. Aggregate arguments are ones
+ # that are separate but will be fused horizontally. For example a static LSTM
+ # has a lstm cell for each time step. Each one has a separate opHint, but a
+ # fused SequentialLSTM will treat this as a single tensor.
+ FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
+ # The way in which multiple parts of the aggregate argument will be joined
+ # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
+ # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
+ FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
+ # On fused OpHint stub, the order of inputs that the final LSTM call will
+ # have. What this means is that the TensorFlow order might be
+ # "foo", "bar", "stuff" and you might want the TF lite op order to be
+ # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
+ # attribute to [2, 0, 1, -1].
+ TFLITE_INPUT_INDICES = "_tflite_input_indices"
+
+ # Types of aggregations
+ # stack: stacks all ophints with matching tags. i.e. for a static rnn.
+ # specifically, this is good for an input or output to a static rnn cell.
+ AGGREGATE_STACK = _compat.as_bytes("stack")
+ # first: only takes the first output (one with lowest sort index)
+ # of matching tags. This is good for the input state to an RNN.
+ AGGREGATE_FIRST = _compat.as_bytes("first")
+ # aggregation last takes only the last tag (one with highest sort index).
+ # This is good for an output value on the last stack item of a
+ # static rnn.
+ AGGREGATE_LAST = _compat.as_bytes("last")
+
+ class OpHintArgumentTracker(object):
+ """Conceptually tracks indices of arguments of "OpHint functions".
+
+ The inputs and arguments of these functions both use an instance
+ of the class so they can have independent numbering."""
+
+ def __init__(self, function_name, unique_function_id, node_name_prefix,
+ attr_name):
+ """Initialize ophint argument.
+
+ Args:
+ function_name: Name of the function that this tracks arguments for.
+ unique_function_id: UUID of function that this tracks arguments for.
+ node_name_prefix: How identities that are created are named.
+ attr_name: Name of attribute to use to store the index for this hint.
+ i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
+ """
+
+ # The global index is the argument index of the op. This is in contrast
+ # to the sort index which is the sequence number of a particular instance
+ # of a given global index. For example, you may have called add hint
+ # twice with the tag "foo". Then the global index will be 0 for both
+ # and the sort index will be 0 for the first added and 1 for the second.
+ self._function_name = function_name
+ self._unique_function_id = unique_function_id
+ self._next_global_index = 0 # The absolute global index
+ self._used_global_indices = set()
+ self._tag_to_global_index = {} # The argument index a given tag maps to
+ self._tag_to_next_sort_index = {} # The current index for each tag
+ self._node_name_prefix = node_name_prefix
+ self._attr_name = attr_name
+
+ def _get_new_global_index(self, index_override):
+ """Return the next unused argument index in order or use an override.
+
+ Args:
+ index_override: An index to use instead of the next available or None
+ to use the next available.
+
+ Returns:
+ A valid global_index to use for the next hint argument.
+
+ Raises:
+ ValueError: If the index_override is already used by another hint.
+ """
+ if index_override is None:
+ global_index = self._next_global_index
+ else:
+ if index_override in self._used_global_indices:
+ raise ValueError("Index %d was already used by another call to add")
+ global_index = index_override
+ # Make next_global_index valid
+ self._used_global_indices.add(global_index)
+ while self._next_global_index in self._used_global_indices:
+ self._next_global_index += 1
+ return global_index
+
+ def add(self, arg, tag=None, name=None, aggregate=None,
+ index_override=None):
+ """Return a wrapped tensor of an input tensor as an argument.
+
+ Args:
+ arg: A TensorFlow tensor that should be considered an argument.
+ tag: String tag to identify arguments that should be packed.
+ name: Name of argument. This is included in the Identity hint op names.
+ aggregate: Strategy to aggregate.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ Note, aggregate is only valid if tag is specified.
+ index_override: Specify what input/output index should this be in the
+ final stub. i.e. add(arg0, index=1); add(arg1, index=0) wil make the
+ final stub be as stub_func(inputs[arg1, arg0], outputs=[]) rather than
+ the default call order based ordering.
+
+ Returns:
+ A tensor representing the wrapped argument.
+
+ Raises:
+ ValueError: When indices are not consistent.
+ """
+
+ # Find the appropriate index
+ if tag is None:
+ if aggregate is not None:
+ raise ValueError("You must specify `tag` if using aggregate.")
+ global_index = self._get_new_global_index(index_override)
+ sort_index = None
+ else:
+ if aggregate is None:
+ raise ValueError("You must specify `aggregate` if using tag.")
+ if tag not in self._tag_to_global_index:
+ self._tag_to_global_index[tag] = (
+ self._get_new_global_index(index_override))
+ self._tag_to_next_sort_index[tag] = 0
+ elif (index_override and
+ index_override != self._tag_to_global_index[tag]):
+ raise ValueError(
+ "Tag %r was called with two indices %r and %r" %
+ (tag, index_override, self._tag_to_global_index[tag]))
+ global_index = self._tag_to_global_index[tag]
+ sort_index = self._tag_to_next_sort_index[tag]
+ self._tag_to_next_sort_index[tag] += 1
+
+ uuid = self._unique_function_id
+ name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
+ uuid, global_index, sort_index, name)
+ identity_op = _array_ops.identity(arg, name=name)
+
+ # pylint: disable=protected-access
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_NAME_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._function_name)))
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_UUID_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._unique_function_id)))
+ identity_op.op._set_attr(
+ self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
+ if sort_index is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_SORT_INDEX_ATTR,
+ _attr_value_pb2.AttrValue(i=sort_index))
+ if aggregate is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_AGGREGATE_ATTR,
+ _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate))))
+ # pylint: enable=protected-access
+ return identity_op
def __init__(self, function_name, **kwargs):
"""Create a OpHint.
@@ -112,10 +284,14 @@ class OpHint(object):
"""
self._function_name = function_name
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
- self._curr_input_index = 0
- self._curr_output_index = 0
self._attrs_to_store_later = kwargs
self._stored_attrs = False
+ self._inputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "InputHint",
+ OpHint.FUNCTION_INPUT_INDEX_ATTR)
+ self._outputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "OutputHint",
+ OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
@@ -124,68 +300,278 @@ class OpHint(object):
tensor=tensor_value.op.node_def.attr["value"].tensor))
# pylint: enable=protected-access
- def add_inputs(self, *args):
+ def add_input(self, *args, **kwargs):
+ """Add a wrapped input argument to the hint.
+
+ Args:
+ *args: The input tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped input tensor.
+ """
+ return self._inputs.add(*args, **kwargs)
+
+ def add_output(self, *args, **kwargs):
+ """Add a wrapped output argument to the hint.
+
+ Args:
+ *args: The output tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped output tensor.
+ """
+ return self._outputs.add(*args, **kwargs)
+
+ def add_inputs(self, *args, **kwargs):
"""Add a sequence of inputs to the function invocation.
Args:
*args: List of inputs to be converted (should be Tf.Tensor).
+ **kwargs: This allows 'names' which should be a list of names.
Returns:
Wrapped inputs (identity standins that have additional metadata). These
are also are also tf.Tensor's.
"""
-
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_INPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_input_index))
- # pylint: enable=protected-access
- self._curr_input_index += 1
- return identity_op
-
- return [augmented_identity(arg) for arg in args]
-
- def add_outputs(self, *args):
+ if "names" in kwargs:
+ return [
+ self._inputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._inputs.add(arg) for arg in args]
+
+ def add_outputs(self, *args, **kwargs):
"""Add a sequence of outputs to the function invocation.
Args:
*args: List of outputs to be converted (should be tf.Tensor).
+ **kwargs: See
Returns:
Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's.
"""
+ if "names" in kwargs:
+ return [
+ self._outputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._outputs.add(arg) for arg in args]
+
+
+class _LiteOperand(object):
+ """Abstract operand for a tflite hint function.
+
+ This is a base class that handles representing arguments to an OpHint.
+ It also is able to serialize operands to the stubbed graph_def.
+ Child classes are responsible for being able to
+ store information about the hint identity operators. They are also responsible
+ for knowing how to serialize to output graphdefs.
+
+ Typically this will be implemented by holding one or more identity nodes
+ that were previously discovered as hints.
+ """
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the node(s) to out_graphdef and returns the input node name.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+ The the output that the stub should use as an input for this operand.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """Add node(s) to graph representing output operands and returns type.
+
+ Args:
+ fused_op_name: name of the fused op stub name.
+ output_index: Output index that we are currently processing from stub.
+ out_graphdef: The destination graphdef we are currently building up.
+
+ Returns:
+ The datatype of this identity.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del fused_op_name, output_index, out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_OUTPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_output_index))
- # pylint: enable=protected-access
- self._curr_output_index += 1
- return identity_op
- wrapped_outputs = [augmented_identity(arg) for arg in args]
+class _LiteSingleOperand(_LiteOperand):
+ """A simple operand that is non-aggregated (i.e. most hints)."""
- if not self._stored_attrs:
- for key, value in self._attrs_to_store_later.iteritems():
- self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value)
- self._stored_attrs = True
+ def __init__(self, node):
+ _LiteOperand.__init__(self)
+ self.node = node
+ self.name = _tensor_name_base(node.name)
- return wrapped_outputs
+ def flatten(self):
+ return [self.name]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ return self.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, index,
+ out_graphdef):
+ output_node = _copy.deepcopy(self.node)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(fused_op_name, index))
+ out_graphdef.node.extend([output_node])
+ return self.node.attr["type"].i
+
+ def __str__(self):
+ return str(self.name)
+
+
+class _LiteAggregateOperand(_LiteOperand):
+ """An operand for a tflite hint function that is aggregated from many.
+
+ For example, an LSTM is a grid of operators that are all related. Inputs
+ going into them may need to be fused, so they should all be tracked as
+ related arguments.
+ """
+
+ def __init__(self, aggregation):
+ _LiteOperand.__init__(self)
+ self.aggregation = aggregation
+ self.names = {}
+ self.nodes = {}
+ self.flattened = None
+
+ def add(self, sort, node):
+ self.names[sort] = _tensor_name_base(node.name)
+ self.nodes[sort] = node
+
+ def flatten_nodes(self):
+ """Return a list of all the node protos in aggregation sorted order."""
+ if not self.flattened:
+ self.flattened = [None] * len(self.nodes)
+ for idx, node in _six.iteritems(self.nodes):
+ self.flattened[idx] = node
+ for n in self.nodes:
+ if n is None:
+ raise RuntimeError("Aggregate was missing argument.")
+ if self.aggregation == OpHint.AGGREGATE_FIRST:
+ self.flattened = self.flattened[:1]
+ elif self.aggregation == OpHint.AGGREGATE_LAST:
+ self.flattened = self.flattened[-1:]
+ elif self.aggregation == OpHint.AGGREGATE_STACK:
+ pass
+ else:
+ raise ValueError(
+ "Invalid aggregation type %r specified" % self.aggregation)
+ return self.flattened
+
+ def flatten(self):
+ """Return a list of all node names in aggregation sorted sorter."""
+ return [_tensor_name_base(x.name) for x in self.flatten_nodes()]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the nodes to out_graphdef and returns an aggregated output.
+
+ In particular, if you have 4 inputs to a hint stub, this will be the
+ node that you can use as an output. I.e. you have 4 timesteps from a
+ static rnn, then a fused UnidriecitonalLSTM will expect 1 input with
+ all 4 time steps. So here we make a pack and return the output name of
+ that pack.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+ The name of a pack that aggregates this node.
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ return _tensor_name_base(flattened[0].name)
+ else:
+ new_node = _node_def_pb2.NodeDef()
+ new_node.op = "Pack"
+ new_node.name = "OpHintStack-%s" % flattened[0].name
+ new_node.attr["N"].i = len(flattened)
+ new_node.attr["T"].type = flattened[0].attr["T"].type
+ for discrete in flattened:
+ new_node.input.append(_tensor_name_base(discrete.name))
+ out_graphdef.node.extend([new_node])
+ return new_node.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """This adds to `out_graphdef` all the unaggregated outputs.
+
+ I.e. we are outputting from a fused stub, but we need to make it compatible
+ with the unfused original graph so we insert an unpack. Ideally in a later
+ stage the unpack -> pack sequences will be removed.
+
+ Args:
+ fused_op_name: The name of the stub we are in the process of fusing.
+ output_index: The output output_index this object represents.
+ out_graphdef: The graphdef we are in the process of buildings
+
+ Returns:
+ The type of the aggregated output (so we can finish building the stub
+ op).
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ temp_op = _LiteSingleOperand(flattened[0])
+ return temp_op.aggregate_and_return_name_for_output(
+ fused_op_name, output_index, out_graphdef)
+ else:
+ stack_node = _node_def_pb2.NodeDef()
+ stack_node.op = "Unpack"
+ stack_node.name = "OpHintUnstack-%s" % flattened[0].name
+ stack_node.attr["num"].i = len(flattened)
+ output_type = flattened[0].attr["T"].type
+ stack_node.attr["T"].type = output_type
+ stack_node.input.append(_tensorflow_output_name(
+ fused_op_name, output_index))
+ out_graphdef.node.extend([stack_node])
+
+ for idx, discrete in enumerate(flattened):
+ output_node = _copy.deepcopy(discrete)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
+ out_graphdef.node.extend([output_node])
+
+ return output_type
+
+ def __str__(self):
+ s = "\t\t\tAGGREGATE %s\n" % self.aggregation
+ for sort, val in self.names.iteritems():
+ s += "\t\t\t%d: %s\n" % (sort, val)
+ return s
class _LiteFuncCall(object):
@@ -212,46 +598,87 @@ class _LiteFuncCall(object):
self.uuid = None
self.params = {}
+ def flattened_inputs_and_outputs(self):
+ """Return a list of inputs and outputs in a flattened format.
+
+ Returns:
+ Tuple of (inputs, outputs). where input and output i a list of names.
+ """
+ def _flatten(input_or_output_dict):
+ flattened_items = []
+ for item in input_or_output_dict.values():
+ flattened_items.extend(item.flatten())
+ return flattened_items
+
+ return _flatten(self.inputs), _flatten(self.outputs)
+
def __str__(self):
- return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % (
- self.function_name, self.uuid, self.inputs, self.outputs)
+ def format_args(items):
+ s = ""
+ for idx, item in items.iteritems():
+ s += ("\t\t%d:\n" % idx) + str(item)
+ return s
+
+ inputs_str = "\tInputs\n" + format_args(self.inputs)
+ outputs_str = "\tOutputs\n" + format_args(self.outputs)
+ return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
+ % (self.function_name, self.uuid, inputs_str, outputs_str))
-def _find_all_hints_in_graph_def(session):
+
+def _find_all_hints_in_graph_def(graphdef):
"""Look at the current default graph and return a list of LiteFuncCall objs.
Args:
- session: A TensorFlow session that contains the graph to convert.
+ graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
Returns:
a list of `LifeFuncCall` objects in the form
"""
func_calls = _collections.defaultdict(_LiteFuncCall)
- seen_ops = set()
-
- for op in session.graph.get_operations():
- for operand in _itertools.chain(op.inputs, op.outputs):
- if operand in seen_ops:
- continue
- seen_ops.add(operand)
- attr = operand.op.node_def.attr
- uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
- if OpHint.FUNCTION_UUID_ATTR not in attr:
- continue
- call_def = func_calls[uuid]
- call_def.uuid = uuid
- if OpHint.FUNCTION_UUID_ATTR in attr:
- call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
- if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
- call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand
- if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
- call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand
-
- for a in attr:
- if a.startswith("_tflite_attr_"):
- # TODO(aselle): Remember the attribute tensors so we can put them
- # in collapse.
- call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
+
+ for node in graphdef.node:
+ attr = node.attr
+ # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
+ uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
+ if (OpHint.FUNCTION_UUID_ATTR not in attr
+ or not attr[OpHint.FUNCTION_UUID_ATTR].s):
+ continue
+
+ # Start building function
+ call_def = func_calls[uuid]
+ call_def.uuid = uuid
+ call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+ # Get sorting and aggregation information
+
+ sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
+ if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
+ if sort == -1: sort = None
+ aggregation = None
+ if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
+ aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s
+
+ # Add the input or output
+ def put_operand(stuff, index, sort, operand, aggregation):
+ """Add a given index into the function structure."""
+ if sort is None:
+ stuff[index] = _LiteSingleOperand(operand)
+ else:
+ if index not in stuff:
+ stuff[index] = _LiteAggregateOperand(aggregation)
+ stuff[index].add(sort, operand)
+
+ if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
+ put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+ if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
+ put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+
+ # Remember attributes
+ for a in attr:
+ if a.startswith("_tflite_attr_"):
+ call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
return func_calls
@@ -267,42 +694,305 @@ def _tensor_name_base(full_tensor_name):
Returns:
A name without any device assignment.
"""
- return full_tensor_name.name.split(":")[0]
+ if full_tensor_name.startswith("^"):
+ return full_tensor_name[1:]
+ return full_tensor_name.split(":")[0]
+
+
+def _tensorflow_output_name(tensor_name, output_index):
+ return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
+ output_index)
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name):
+ """Checks to make sure node only connects to predecessor graph through inputs.
+
+ Args:
+ n: Node to check
+ reachable_by_input: Nodes that are reachable by all inputs of subgraph
+ input_nodes_set: The set of nodes that are "inputs".
+ name_to_input_name: Maps from name to the list of inputs.
+
+ Raises:
+ TypeError: If the given node uses items past inputs directly.
+ """
+ next_to_visit = [n]
+ visited = set()
+ while next_to_visit:
+ current_node = next_to_visit.pop()
+ visited.add(current_node)
+ if (current_node in reachable_by_input
+ and current_node not in input_nodes_set):
+ raise TypeError(
+ "Node %s uses input %s not in input_nodes." % (n, current_node))
+ if current_node not in input_nodes_set:
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node]
+ if input_node not in visited
+ ]
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _convert_single_op_hint_to_stub(call, graph_def):
+ """Given a graph_def, converts `call` into a stub and returns a new graph_def.
+ Args:
+ call: A single function call to be converted.
+ graph_def: A graph_def to use as input (that hass call obviously).
+ Returns:
+ A new transformed graph-def that has call as a stub (single op).
-def convert_op_hints_to_stubs(session):
+ Note: after this process, the graph_def can no longer be loaded into
+ the tensorflow runtime, so all future manipulations are done in graph_def
+ level.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ graph_def)
+ input_names, output_names = call.flattened_inputs_and_outputs()
+
+ reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
+ reachable_by_output = _bfs_for_reachable_nodes(output_names,
+ name_to_input_name)
+ input_nodes_set = set(input_names)
+ output_nodes_set = set(output_names)
+ nodes_after_fuse = []
+ nodes_deleted_by_fuse = set()
+ # Classify each node. We want to keep everything reachable by input, but
+ # we don't know if things that are not reachable by output or input (things
+ # after fusing).
+ for node in graph_def.node:
+ n = _tensor_name_base(node.name)
+ if n in reachable_by_output:
+ if n not in reachable_by_input and n not in output_nodes_set:
+ # n is an internal node. Check to make sure it is really internal.
+ # TODO(aselle): this could be done more efficiently by flooding
+ # the graph first.
+ _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name)
+ nodes_deleted_by_fuse.add(n)
+ elif n not in reachable_by_input:
+ # n is a node that after all the fusings, so keep it.
+ nodes_after_fuse.append(n)
+ else:
+ # n is a node that is randomly in the graph but not connected to
+ # the chain of dependencies.
+ pass
+
+ # Make a new graphdef with all the pre-input and input nodes
+ out = _graph_pb2.GraphDef()
+ reachable_by_input_sorted = sorted(
+ list(reachable_by_input), key=lambda n: name_to_seq_num[n])
+ for node in reachable_by_input_sorted:
+ out.node.extend([_copy.deepcopy(name_to_node[node])])
+
+ # Create any stacks to aggregate arguments into to a single input
+ # i.e. for static_rnn's.
+ # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
+ sorted_input_indices = list(call.inputs.keys())
+ sorted_input_indices.sort()
+ sorted_output_indices = list(call.outputs.keys())
+ sorted_output_indices.sort()
+ new_node = _node_def_pb2.NodeDef()
+ # Delegate to each operand to produce the proper new input for this stub node.
+ # In particular, an aggregate input will now be a Pack of some previously
+ # non-fused things.
+ for input_index in sorted_input_indices:
+ inputs = call.inputs[input_index]
+ new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
+ new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
+
+ # Ceate the function
+ new_node.op = call.function_name
+ new_node.name = call.uuid
+ out.node.extend([new_node])
+
+ # Now call each output argument to give them a chance to make the proper
+ # output type and add it to our new_node.
+ output_dtypes = []
+ for output_index in sorted_output_indices:
+ output = call.outputs[output_index]
+ output_dtype = (
+ output.aggregate_and_return_name_for_output(new_node.name, output_index,
+ out))
+ output_dtypes.append(output_dtype)
+ new_node.attr["_output_types"].list.type[:] = output_dtypes
+ # TODO(aselle): what is right here?
+ new_node.attr["_output_quantized"].b = False
+
+ # Add post output nodes that do not depend on the outputs
+ for n in nodes_after_fuse:
+ should_keep = True
+ for input_name in name_to_input_name[n]:
+ if input_name in nodes_deleted_by_fuse:
+ should_keep = False
+ if should_keep:
+ out.node.extend([_copy.deepcopy(name_to_node[n])])
+
+ # Misc. graph_def data that needs copying.
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+
+ return out
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _remove_one_redundant_stack_unstack(in_graph_def):
+ """Removes a stack->unstack pattern from in_graph_def in a returned graph.
+
+ Args:
+ in_graph_def: Graph def to use as input.
+ Returns:
+ Simplified tuple (graph_def, changed_something) where changed_something
+ is true if anything was done.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ in_graph_def)
+ del name_to_seq_num
+
+ # TODO(aselle): Make this not hardcoded.
+ do_generic_pack_unpack = True
+
+ out = _graph_pb2.GraphDef()
+ out.library.CopyFrom(in_graph_def.library)
+ out.versions.CopyFrom(in_graph_def.versions)
+ for n in in_graph_def.node:
+ node_name = _tensor_name_base(n.name)
+ if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
+ continue
+ next_to_visit = [node_name]
+ visited = set()
+
+ unpack_nodes = set()
+ pack_node = node_name
+
+ # Find a pattern of unstack connected to a stack (with identities
+ # in between.
+ matches_pattern = True
+ is_hint_created_stack = False
+ while next_to_visit:
+ current_node_name = next_to_visit[0]
+ visited.add(current_node_name)
+ del next_to_visit[0]
+ node = name_to_node[current_node_name]
+ is_op_hint_stack = node.name.startswith("OpHintStack")
+ is_op_hint_unstack = node.name.startswith("OpHintUnstack")
+ if (node.op == "Identity" or is_op_hint_stack
+ or (do_generic_pack_unpack and node.op == "Pack")):
+ is_hint_created_stack |= is_op_hint_stack
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node_name]
+ if input_node not in visited
+ ]
+ elif (is_op_hint_unstack
+ or (do_generic_pack_unpack and node.op == "Unpack")):
+ unpack_nodes.add(node.name)
+ is_hint_created_stack &= is_op_hint_unstack
+ else:
+ matches_pattern = False
+ break
+ visited.add(node.name)
+
+ if matches_pattern and len(unpack_nodes) == 1:
+ pack_node = node_name
+
+ # Check to see if anyone depends on the intermediate identity or the
+ # Unstacked form
+ no_external_dependency = True
+ for other_n in in_graph_def.node:
+ if other_n.name in visited: continue
+ for input_tensor in name_to_input_name[other_n.name]:
+ input_op = _tensor_name_base(input_tensor)
+ if input_op in visited and input_op != pack_node:
+ no_external_dependency = False
+ # Proceed with the substitution if the stack/unstack pair was created
+ # through hints, or that it was not, but nobody is consuming things
+ # between the stack and unstack.
+ if is_hint_created_stack or no_external_dependency:
+ end = unpack_nodes.pop()
+ end_input = name_to_node[end].input[0]
+ # All nodes that depend on the final stack need to be redone to use
+ for other_n in in_graph_def.node:
+ node_name = _tensor_name_base(other_n.name)
+ if node_name not in visited:
+ new_node = _copy.deepcopy(other_n)
+ new_node.input[:] = [
+ (end_input if stripped == pack_node else
+ non_stripped) for stripped, non_stripped in zip(
+ name_to_input_name[node_name], new_node.input[:])
+ ]
+ out.node.extend([new_node])
+ return out, True
+ return in_graph_def, False
+
+
+def _remove_redundant_stack_unstack(graph_def):
+ curr = graph_def
+ del graph_def
+ changed_stuff = True
+ while changed_stuff:
+ curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
+ return curr
+
+
+def _convert_op_hints_to_stubs_helper(
+ graph_def, write_callback=lambda sess, graph_def: None):
+ """Converts a graph_def to a new graph_def where all op hints are stubbed.
+
+ Args:
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
+ Returns:
+ A new stubbed graph_def.
+ """
+
+ hints = _find_all_hints_in_graph_def(graph_def)
+ curr_graph_def = graph_def
+ del graph_def # prevent using graph_def again (common source of error)
+ for hint in _six.itervalues(hints):
+ curr_graph_def = _convert_single_op_hint_to_stub(
+ hint, curr_graph_def)
+ write_callback(curr_graph_def, "initial")
+ # The stubbing process can create stacks/unstacks in the case of LSTMs
+ # remove them.
+ curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
+ return curr_graph_def
+
+
+def convert_op_hints_to_stubs(session=None,
+ graph_def=None,
+ write_callback=lambda graph_def, comments: None):
"""Converts a graphdef with LiteOp hints into stub operations.
This is used to prepare for toco conversion of complex intrinsic usages.
+ Note: only one of session or graph_def should be used, not both.
Args:
session: A TensorFlow session that contains the graph to convert.
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
Returns:
A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters.
+ Raises:
+ ValueError: If both session and graph_def are provided.
"""
- hints = _find_all_hints_in_graph_def(session)
- current_graph_def = session.graph_def
- for call in hints.values():
- input_names = [None] * len(call.inputs)
- output_names = [None] * len(call.outputs)
- output_dtypes = [None] * len(call.outputs)
- output_quantized = False
- for input_index, tensor in call.inputs.items():
- input_names[input_index] = _tensor_name_base(tensor)
- for output_index, tensor in call.outputs.items():
- output_names[output_index] = _tensor_name_base(tensor)
- output_dtypes[output_index] = tensor.dtype.as_datatype_enum
- # TODO(aselle): Support quantized flag properly
- current_graph_def = _framework.fuse_op(
- current_graph_def, input_names, output_names, output_dtypes,
- output_quantized, call.uuid, call.function_name)
- for node in current_graph_def.node:
- if node.name == call.uuid:
- for param, tensor in call.params.items():
- node.attr[param].tensor.CopyFrom(tensor)
- return current_graph_def
-
-
-_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"]
+
+ if session is not None and graph_def is not None:
+ raise ValueError("Provide only one of session and graph_def.")
+
+ if session is not None:
+ return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
+ elif graph_def is not None:
+ return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
+ else:
+ raise ValueError("Must specify session or graph_def as input.")
+
+
+_allowed_symbols = [
+ "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new"
+]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index a76cc39635..7d7a4ba94a 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -47,6 +47,9 @@ def _get_toco_converter(flags):
Returns:
TocoConverter object.
+
+ Raises:
+ ValueError: Invalid flags.
"""
# Parse input and output arrays.
input_arrays = _parse_array(flags.input_arrays)
@@ -77,6 +80,9 @@ def _get_toco_converter(flags):
elif flags.keras_model_file:
converter_fn = lite.TocoConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
+ else:
+ raise ValueError("--graph_def_file, --saved_model_dir, or "
+ "--keras_model_file must be specified.")
return converter_fn(**converter_kwargs)
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 14f88b4c00..e2c537fa4d 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -169,6 +169,7 @@ enum BuiltinOperator : byte {
ONE_HOT = 85,
LOGICAL_AND = 86,
LOGICAL_NOT = 87,
+ UNPACK = 88,
}
// Options for the builtin operators.
@@ -236,6 +237,7 @@ union BuiltinOptions {
OneHotOptions,
LogicalAndOptions,
LogicalNotOptions,
+ UnpackOptions,
}
enum Padding : byte { SAME, VALID }
@@ -565,6 +567,11 @@ table LogicalAndOptions {
table LogicalNotOptions {
}
+table UnpackOptions {
+ num:int;
+ axis:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 3efa153e2c..d367d9a93a 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -220,6 +220,9 @@ struct LogicalAndOptionsT;
struct LogicalNotOptions;
struct LogicalNotOptionsT;
+struct UnpackOptions;
+struct UnpackOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -373,11 +376,12 @@ enum BuiltinOperator {
BuiltinOperator_ONE_HOT = 85,
BuiltinOperator_LOGICAL_AND = 86,
BuiltinOperator_LOGICAL_NOT = 87,
+ BuiltinOperator_UNPACK = 88,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT
+ BuiltinOperator_MAX = BuiltinOperator_UNPACK
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[88] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -465,7 +469,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] {
BuiltinOperator_LOGICAL_OR,
BuiltinOperator_ONE_HOT,
BuiltinOperator_LOGICAL_AND,
- BuiltinOperator_LOGICAL_NOT
+ BuiltinOperator_LOGICAL_NOT,
+ BuiltinOperator_UNPACK
};
return values;
}
@@ -560,6 +565,7 @@ inline const char **EnumNamesBuiltinOperator() {
"ONE_HOT",
"LOGICAL_AND",
"LOGICAL_NOT",
+ "UNPACK",
nullptr
};
return names;
@@ -635,11 +641,12 @@ enum BuiltinOptions {
BuiltinOptions_OneHotOptions = 61,
BuiltinOptions_LogicalAndOptions = 62,
BuiltinOptions_LogicalNotOptions = 63,
+ BuiltinOptions_UnpackOptions = 64,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions
+ BuiltinOptions_MAX = BuiltinOptions_UnpackOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[65] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -704,7 +711,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] {
BuiltinOptions_LogicalOrOptions,
BuiltinOptions_OneHotOptions,
BuiltinOptions_LogicalAndOptions,
- BuiltinOptions_LogicalNotOptions
+ BuiltinOptions_LogicalNotOptions,
+ BuiltinOptions_UnpackOptions
};
return values;
}
@@ -775,6 +783,7 @@ inline const char **EnumNamesBuiltinOptions() {
"OneHotOptions",
"LogicalAndOptions",
"LogicalNotOptions",
+ "UnpackOptions",
nullptr
};
return names;
@@ -1041,6 +1050,10 @@ template<> struct BuiltinOptionsTraits<LogicalNotOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions;
};
+template<> struct BuiltinOptionsTraits<UnpackOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1576,6 +1589,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LogicalNotOptions ?
reinterpret_cast<const LogicalNotOptionsT *>(value) : nullptr;
}
+ UnpackOptionsT *AsUnpackOptions() {
+ return type == BuiltinOptions_UnpackOptions ?
+ reinterpret_cast<UnpackOptionsT *>(value) : nullptr;
+ }
+ const UnpackOptionsT *AsUnpackOptions() const {
+ return type == BuiltinOptions_UnpackOptions ?
+ reinterpret_cast<const UnpackOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5649,6 +5670,72 @@ inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(
flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct UnpackOptionsT : public flatbuffers::NativeTable {
+ typedef UnpackOptions TableType;
+ int32_t num;
+ int32_t axis;
+ UnpackOptionsT()
+ : num(0),
+ axis(0) {
+ }
+};
+
+struct UnpackOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnpackOptionsT NativeTableType;
+ enum {
+ VT_NUM = 4,
+ VT_AXIS = 6
+ };
+ int32_t num() const {
+ return GetField<int32_t>(VT_NUM, 0);
+ }
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_NUM) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+ UnpackOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(UnpackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<UnpackOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct UnpackOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num(int32_t num) {
+ fbb_.AddElement<int32_t>(UnpackOptions::VT_NUM, num, 0);
+ }
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(UnpackOptions::VT_AXIS, axis, 0);
+ }
+ explicit UnpackOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnpackOptionsBuilder &operator=(const UnpackOptionsBuilder &);
+ flatbuffers::Offset<UnpackOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnpackOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num = 0,
+ int32_t axis = 0) {
+ UnpackOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_num(num);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5971,6 +6058,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const {
return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast<const LogicalNotOptions *>(builtin_options()) : nullptr;
}
+ const UnpackOptions *builtin_options_as_UnpackOptions() const {
+ return builtin_options_type() == BuiltinOptions_UnpackOptions ? static_cast<const UnpackOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6254,6 +6344,10 @@ template<> inline const LogicalNotOptions *Operator::builtin_options_as<LogicalN
return builtin_options_as_LogicalNotOptions();
}
+template<> inline const UnpackOptions *Operator::builtin_options_as<UnpackOptions>() const {
+ return builtin_options_as_UnpackOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8441,6 +8535,35 @@ inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffer
_fbb);
}
+inline UnpackOptionsT *UnpackOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new UnpackOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void UnpackOptions::UnPackTo(UnpackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num(); _o->num = _e; };
+ { auto _e = axis(); _o->axis = _e; };
+}
+
+inline flatbuffers::Offset<UnpackOptions> UnpackOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateUnpackOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnpackOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num = _o->num;
+ auto _axis = _o->axis;
+ return tflite::CreateUnpackOptions(
+ _fbb,
+ _num,
+ _axis);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8882,6 +9005,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9152,6 +9279,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9410,6 +9541,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const LogicalNotOptionsT *>(value);
return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptionsT *>(value);
+ return CreateUnpackOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9668,6 +9803,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new LogicalNotOptionsT(*reinterpret_cast<LogicalNotOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_UnpackOptions: {
+ value = new UnpackOptionsT(*reinterpret_cast<UnpackOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9990,6 +10129,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<UnpackOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h
index 7f8f4e851e..af3fadfcb3 100644
--- a/tensorflow/contrib/lite/string.h
+++ b/tensorflow/contrib/lite/string.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Abstract string. We don't want even absl at this level.
-#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
-#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_STRING_H_
+#define TENSORFLOW_CONTRIB_LITE_STRING_H_
#include <string>
@@ -26,4 +26,4 @@ using std::string;
} // namespace tflite
-#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#endif // TENSORFLOW_CONTRIB_LITE_STRING_H_
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 9dd5c8ae44..597ee8fb1e 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1324,6 +1324,71 @@ def make_conv_with_shared_weights_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+# Note: This is a regression test for a bug (b/112303004) that Toco incorrectly
+# transforms Conv into DepthwiseConv when two Conv ops share the same constant
+# weight tensor.
+def make_conv_to_depthwiseconv_with_shared_weights_tests(zip_path):
+ """Make a test where 2 Conv ops shared the same constant weight tensor."""
+
+ test_parameters = [{
+ "input_shape": [[1, 10, 10, 1]],
+ "filter_shape": [[3, 3]],
+ "strides": [[1, 1, 1, 1]],
+ "dilations": [[1, 1, 1, 1]],
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ "channel_multiplier": [3],
+ }]
+
+ def get_tensor_shapes(parameters):
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_shape"]
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]
+ ]
+ return [input_shape, filter_shape]
+
+ def build_graph(parameters):
+ """Build a conv graph given `parameters`."""
+ input_shape, filter_shape = get_tensor_shapes(parameters)
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+
+ # Construct a constant weights tensor which will be used by both Conv2D.
+ filter_tensor = tf.constant(
+ create_tensor_data(np.float32, filter_shape), dtype=tf.float32)
+ input_tensors = [input_tensor]
+
+ # Construct 2 Conv2D operations which use exactly the same input and
+ # weights.
+ result1 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ result2 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ # Add the 2 results up.
+ out = result1 + result2
+ return input_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ # Build list of input values either containing 1 tensor (input) or 2 tensors
+ # (input, filter) based on whether filter is constant or variable input.
+ input_shape, unused_filter_shape = get_tensor_shapes(parameters)
+ values = [create_tensor_data(np.float32, input_shape)]
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_depthwiseconv_tests(zip_path):
"""Make a set of tests to do convolution."""
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 71a98a3d56..4dacf9c84b 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -139,7 +139,7 @@ class TfLiteDriver::Expectation {
TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
: use_nnapi_(use_nnapi) {
if (delegate_name == "EAGER") {
- delegate_.reset(new EagerDelegate());
+ delegate_ = EagerDelegate::Create();
}
}
@@ -173,7 +173,9 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
interpreter_->UseNNAPI(use_nnapi_);
if (delegate_) {
- if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) {
+ if (interpreter_->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true) !=
+ kTfLiteOk) {
Invalidate("Unable to the build graph using the delegate");
return;
}
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 1f3ea2e1c7..18c904c6d4 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -106,6 +106,17 @@ class Allocator {
// Core allocation routine.
void Allocate(std::size_t size, Alloc* result) {
+ if (size == 0) {
+ // zero-sized arrays get a dummy alloc of (0, 0) that does not
+ // need to be kept in the books (no need to insert that into
+ // live_allocs_).
+ // Note: zero-sized arrays shouldn't exist, but handling that case
+ // here allows such pathological cases to get a cleaner error message
+ // later instead of generating spurious allocator failures.
+ result->start = 0;
+ result->end = 0;
+ return;
+ }
// Naive algorithm: pick the first gap between live allocations,
// that is wide enough for the new array.
std::size_t pos = 0;
@@ -128,6 +139,11 @@ class Allocator {
}
void Deallocate(const Alloc& a) {
+ // Special-case dummy allocs for zero-sized arrays.
+ if (a.start == 0 && a.end == 0) {
+ // Nothing needs to be done, these aren't kept in the books.
+ return;
+ }
auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a);
CHECK(iter != live_allocs_.end());
CHECK(*iter == a);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index 1ea83abf8e..e88839be5d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -48,7 +48,17 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
// dimension.
return false;
}
- auto& weights_array = model->GetArray(conv_op->inputs[1]);
+
+ const auto& weights_name = conv_op->inputs[1];
+ if (CountOpsWithInput(*model, weights_name) > 1) {
+ // TODO(yunluli): Come up with a way to do the weights shuffling only once.
+ AddMessageF(
+ "Not changing %s to DepthwiseConv because the weights is consumed by "
+ "another op.",
+ LogName(*conv_op));
+ return false;
+ }
+ auto& weights_array = model->GetArray(weights_name);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index d395d7a6a0..f5f2f77460 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -117,6 +117,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
&quantized_max);
if (fakequant_op->narrow_range) {
quantized_min++;
+ output_array.narrow_range = true;
}
// It is important for matching accuracy between TF training and TFLite
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index 41562ab393..a6f665b5f0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -100,13 +100,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolving constant reshape of %s", LogName(*op));
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
// Erase input arrays if no longer used.
for (const auto& input : op->inputs) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
index 0b0d070714..5cfa1a5582 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -128,15 +128,7 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
multiples_array.data_type == ArrayDataType::kInt64)
<< "Only int32/int64 indices are supported";
- // Copy min/max info if present. The ranges of the selected values may be
- // a subset of the original range but we want to ensure the quantization
- // params stay the same.
- if (input_array.minmax) {
- const auto& input_minmax = input_array.GetMinMax();
- auto& output_minmax = output_array.GetOrCreateMinMax();
- output_minmax.min = input_minmax.min;
- output_minmax.max = input_minmax.max;
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
CHECK(!output_array.buffer);
switch (output_array.data_type) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
index 1fd20314b1..fe15dfa06f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
@@ -128,13 +128,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
}
const Array& input_array = model->GetArray(op->inputs[0]);
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
index 5f0cece67a..fedf4441e2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
@@ -154,6 +154,7 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
pack_op->inputs = pack_inputs;
pack_op->outputs = {batch_op->outputs[0]};
pack_op->axis = 0;
+ pack_op->values_count = pack_inputs.size();
model->operators.emplace(tail_it, pack_op);
// Remove the old batch matmul now that we've unrolled.
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
index 7e8ad9c1da..ee054bbed9 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.h
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
-#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
#include <Python.h>
#include <string>
@@ -33,4 +33,4 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
} // namespace toco
-#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2ad2719811..3a4542f522 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -2278,4 +2278,14 @@ void UndoWeightsShuffling(Model* model) {
}
}
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
+ if (src.minmax) {
+ dst->GetOrCreateMinMax() = src.GetMinMax();
+ }
+ if (src.quantization_params) {
+ dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
+ }
+ dst->narrow_range = src.narrow_range;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index b99e6111fe..bdeb203024 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -348,6 +348,9 @@ tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
// so that the rest of toco doesn't need to know about shuffled weights.
void UndoWeightsShuffling(Model* model);
+// Copies minmax, quantization_params, and narrow_range.
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 2cb07eb6ec..dc97d22401 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -5,8 +5,8 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
common_copts = ["-Wall"] + tflite_copts()
@@ -35,6 +35,25 @@ cc_binary(
],
)
+cc_binary(
+ name = "benchmark_model_plus_eager",
+ srcs = [
+ "benchmark_main.cc",
+ ],
+ copts = common_copts + ["-DTFLITE_EXTENDED"],
+ linkopts = tflite_linkopts() + select({
+ "//tensorflow:android": [
+ "-pie", # Android 5.0 and later supports only PIE
+ "-lm", # some builtin ops, e.g., tanh, need -lm
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ ":benchmark_tflite_model_plus_eager_lib",
+ ":logging",
+ ],
+)
+
cc_test(
name = "benchmark_test",
srcs = ["benchmark_test.cc"],
@@ -88,7 +107,25 @@ cc_library(
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/profiling:profile_summarizer",
- "//tensorflow/contrib/lite/profiling:profiler",
+ ],
+)
+
+cc_library(
+ name = "benchmark_tflite_model_plus_eager_lib",
+ srcs = [
+ "benchmark_tflite_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_tflite_model.h"],
+ copts = common_copts + ["-DTFLITE_EXTENDED"],
+ deps = [
+ ":benchmark_model_lib",
+ ":logging",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
],
)
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 7f97f5d0cd..02039922b4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,6 +23,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#ifdef TFLITE_EXTENDED
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#endif // TFLITE_EXTENDED
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -261,6 +264,16 @@ void BenchmarkTfLiteModel::Init() {
bool use_nnapi = params_.Get<bool>("use_nnapi");
interpreter->UseNNAPI(use_nnapi);
+
+#ifdef TFLITE_EXTENDED
+ TFLITE_LOG(INFO) << "Instantiating Eager Delegate";
+ delegate_ = EagerDelegate::Create();
+ if (delegate_) {
+ interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
+ }
+#endif // TFLITE_EXTENDED
+
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 9931dcbafe..4b22d80cbb 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,6 +20,9 @@ limitations under the License.
#include <string>
#include <vector>
+#ifdef TFLITE_EXTENDED
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#endif // TFLITE_EXTENDED
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -52,6 +55,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
public:
BenchmarkTfLiteModel();
BenchmarkTfLiteModel(BenchmarkParams params);
+ virtual ~BenchmarkTfLiteModel() {}
std::vector<Flag> GetFlags() override;
void LogParams() override;
@@ -59,7 +63,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
uint64_t ComputeInputBytes() override;
void Init() override;
void RunImpl() override;
- virtual ~BenchmarkTfLiteModel() {}
struct InputLayerInfo {
std::string name;
@@ -67,6 +70,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
};
private:
+#ifdef TFLITE_EXTENDED
+ std::unique_ptr<EagerDelegate> delegate_;
+#endif // TFLITE_EXTENDED
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index 8ccb65c24f..7950653da9 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -14,8 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/util.h"
+#include <cstring>
+
namespace tflite {
+bool IsEagerOp(const char* custom_name) {
+ return custom_name && strncmp(custom_name, kEagerCustomCodePrefix,
+ strlen(kEagerCustomCodePrefix)) == 0;
+}
+
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
return ConvertArrayToTfLiteIntArray(input.size(), input.data());
}
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index 3c4801183b..f5b208afbb 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -26,6 +26,16 @@ limitations under the License.
namespace tflite {
+// The prefix of Eager op custom code.
+// This will be matched agains the `custom_code` field in `OperatorCode`
+// Flatbuffer Table.
+// WARNING: This is an experimental API and subject to change.
+constexpr char kEagerCustomCodePrefix[] = "Eager";
+
+// Checks whether the prefix of the custom name indicates the operation is an
+// Eager operation.
+bool IsEagerOp(const char* custom_name);
+
// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership
// of the returned pointer.
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 04579c53aa..32bf917a59 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -41,6 +41,16 @@ TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) {
TfLiteIntArrayFree(output);
}
+TEST(UtilTest, IsEagerOp) {
+ EXPECT_TRUE(IsEagerOp("Eager"));
+ EXPECT_TRUE(IsEagerOp("EagerOp"));
+ EXPECT_FALSE(IsEagerOp("eager"));
+ EXPECT_FALSE(IsEagerOp("Eage"));
+ EXPECT_FALSE(IsEagerOp("OpEager"));
+ EXPECT_FALSE(IsEagerOp(nullptr));
+ EXPECT_FALSE(IsEagerOp(""));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index e3928a82a2..83e80f25bc 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -34,6 +34,7 @@ tf_py_test(
":lookup_py",
"//third_party/py/numpy",
"@six_archive//:six",
+ "//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 8c0bfefb30..f83765a48d 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_lookup_ops
@@ -39,6 +42,7 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex
from tensorflow.python.ops.lookup_ops import TextFileInitializer
from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer
# pylint: enable=unused-import
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.deprecation import deprecated
@@ -285,7 +289,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
return table.lookup(tensor)
-class MutableHashTable(LookupInterface):
+class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation.
Data can be inserted by calling the insert method. It does not support
@@ -336,6 +340,13 @@ class MutableHashTable(LookupInterface):
dtype=value_dtype)
self._value_shape = self._default_value.get_shape()
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
@@ -355,9 +366,12 @@ class MutableHashTable(LookupInterface):
value_dtype=value_dtype,
value_shape=self._default_value.get_shape(),
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableHashTable, self).__init__(key_dtype, value_dtype,
- self._table_ref.op.name.split(
- "/")[-1])
+ op_name)
if checkpoint:
saveable = MutableHashTable._Saveable(self, name)
@@ -419,11 +433,10 @@ class MutableHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- # pylint: disable=protected-access
- lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
- # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
+ keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
+ values = ops.convert_to_tensor(values, self._value_dtype, name="values")
with ops.colocate_with(self._table_ref):
# pylint: disable=protected-access
op = gen_lookup_ops.lookup_table_insert_v2(
@@ -447,6 +460,10 @@ class MutableHashTable(LookupInterface):
self._table_ref, self._key_dtype, self._value_dtype, name=name)
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(MutableHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableHashTable."""
@@ -459,14 +476,15 @@ class MutableHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
self.op._table_ref, restored_tensors[0], restored_tensors[1])
-class MutableDenseHashTable(LookupInterface):
+class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation using tensors as backing store.
Data can be inserted by calling the insert method. It does not support
@@ -537,6 +555,13 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(
empty_key, dtype=key_dtype, name="empty_key")
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -545,8 +570,12 @@ class MutableDenseHashTable(LookupInterface):
value_shape=self._value_shape,
initial_num_buckets=initial_num_buckets,
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableDenseHashTable, self).__init__(
- key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])
+ key_dtype, value_dtype, op_name)
if checkpoint:
saveable = MutableDenseHashTable._Saveable(self, name)
@@ -637,6 +666,11 @@ class MutableDenseHashTable(LookupInterface):
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(
+ MutableDenseHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableDenseHashTable."""
@@ -649,7 +683,8 @@ class MutableDenseHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 6fb5244fc6..f9b0358a36 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -23,6 +23,7 @@ import numpy as np
import six
from tensorflow.contrib import lookup
+from tensorflow.contrib.data.python.ops import counter
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -37,6 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
+from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
@@ -382,6 +384,59 @@ class MutableHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(20.0, name="v1")
+
+ default_val = -1
+ keys = constant_op.constant(["b", "c", "d"], dtypes.string)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+ self.evaluate([v0.initializer, v1.initializer])
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(0, self.evaluate(table.size()))
+ self.evaluate(table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ save_path = checkpoint.save(save_prefix)
+ del table, checkpoint, v0, v1
+
+ v0 = variables.Variable(-1.0, name="v0")
+ v1 = variables.Variable(-1.0, name="v1")
+ default_val = -1
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+ self.evaluate(table.insert(
+ constant_op.constant(["a", "c"], dtypes.string),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(table.size()))
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+
+ # Restore the saved values in the parameter nodes.
+ checkpoint.restore(save_path).run_restore_ops()
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ input_string = constant_op.constant(["a", "b", "c", "d", "e"],
+ dtypes.string)
+ output = table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
def testSharing(self):
# Start a server to store the table state
server = server_lib.Server(
@@ -646,11 +701,11 @@ class MutableHashTableOpTest(test.TestCase):
default_val)
# insert with keys of the wrong type
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.insert(constant_op.constant([4, 5, 6]), values).run()
# insert with values of the wrong type
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.insert(keys, constant_op.constant(["a", "b", "c"])).run()
self.assertAllEqual(0, table.size().eval())
@@ -1009,6 +1064,60 @@ class MutableDenseHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ default_value = -1
+ empty_key = 0
+ keys = constant_op.constant([11, 12, 13], dtypes.int64)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ save_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=32)
+
+ save_checkpoint = checkpointable.Checkpoint(table=save_table)
+
+ self.assertAllEqual(0, self.evaluate(save_table.size()))
+ self.evaluate(save_table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(save_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(save_table.export()[0])))
+
+ save_path = save_checkpoint.save(save_prefix)
+ del save_table, save_checkpoint
+
+ load_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=64)
+ self.evaluate(load_table.insert(
+ constant_op.constant([11, 14], dtypes.int64),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(load_table.size()))
+ self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))
+
+ restore_checkpoint = checkpointable.Checkpoint(table=load_table)
+
+ # Restore the saved values in the parameter nodes.
+ restore_checkpoint.restore(save_path).run_restore_ops()
+
+ self.assertAllEqual(3, self.evaluate(load_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(load_table.export()[0])))
+
+ input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
+ output = load_table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
def testVectorSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
@@ -2397,5 +2506,60 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec([None, 2]))
+class MutableHashTableBenchmark(test.Benchmark):
+
+ def _create_table(self):
+ return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0)
+
+ def benchmark_single_repeated_scalar_insert_scalar(self):
+ table = self._create_table()
+ value = variables.Variable(1.0)
+ insert = table.insert(0, value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
+ assert sess.run(size) == 1
+
+ def benchmark_many_repeated_scalar_insert_scalar(self):
+ table = self._create_table()
+ c = counter.Counter().make_one_shot_iterator().get_next()
+ value = variables.Variable(1.0)
+ insert = table.insert(c, value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
+ assert sess.run(size) >= 10000
+
+ def benchmark_single_repeated_batch_32_insert_scalar(self):
+ table = self._create_table()
+ value = variables.Variable([1.0] * 32)
+ insert = table.insert(list(range(32)), value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
+ assert sess.run(size) == 32
+
+ def benchmark_many_repeated_batch_32_insert_scalar(self):
+ table = self._create_table()
+ c = counter.Counter().make_one_shot_iterator().get_next()
+ value = variables.Variable([1.0] * 32)
+ insert = table.insert(32 * c + list(range(32)), value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
+ assert sess.run(size) >= 1000*32
+
+
+class MutableDenseHashTableBenchmark(MutableHashTableBenchmark):
+
+ def _create_table(self):
+ return lookup.MutableDenseHashTable(
+ dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py
index db58647d48..92b380df53 100644
--- a/tensorflow/contrib/losses/__init__.py
+++ b/tensorflow/contrib/losses/__init__.py
@@ -15,7 +15,7 @@
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py
index 6e9d1d4a77..1675387227 100644
--- a/tensorflow/contrib/losses/python/losses/__init__.py
+++ b/tensorflow/contrib/losses/python/losses/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/losses/python/metric_learning/__init__.py b/tensorflow/contrib/losses/python/metric_learning/__init__.py
index 4e551d6aca..3d93a4d0ac 100644
--- a/tensorflow/contrib/losses/python/metric_learning/__init__.py
+++ b/tensorflow/contrib/losses/python/metric_learning/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
@@ -35,5 +35,3 @@ _allowed_symbols = [
'triplet_semihard_loss',
]
remove_undocumented(__name__, _allowed_symbols)
-
-
diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh
index a28fc3a87f..cb4c94d92f 100755
--- a/tensorflow/contrib/makefile/compile_nsync.sh
+++ b/tensorflow/contrib/makefile/compile_nsync.sh
@@ -256,6 +256,7 @@ for arch in $archs; do
esac
makefile='
+ AR := ${NDK_ROOT}/toolchains/'"$toolchain"'/prebuilt/'"$android_os_arch"'/bin/'"$bin_prefix"'-ar
CC=${CC_PREFIX} \
${NDK_ROOT}/toolchains/'"$toolchain"'/prebuilt/'"$android_os_arch"'/bin/'"$bin_prefix"'-g++
PLATFORM_CPPFLAGS=--sysroot \
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index ecf2e120df..66a3315700 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -301,7 +301,6 @@ tensorflow/core/ops/array_grad.cc
tensorflow/core/kernels/spacetobatch_functor.cc
tensorflow/core/kernels/spacetobatch_op.cc
tensorflow/core/kernels/batchtospace_op.cc
-tensorflow/core/kernels/warn_about_ints.cc
tensorflow/core/kernels/segment_reduction_ops.cc
tensorflow/core/ops/audio_ops.cc
tensorflow/core/kernels/decode_proto_op.cc
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 88798d61b7..5645784f8d 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Ops for evaluation metrics and summary statistics.
-See the @{$python/contrib.metrics} guide.
+See the
+[Contrib Metrics](https://tensorflow.org/api_guides/python/contrib.metrics)
+guide.
@@auc_with_confidence_intervals
@@streaming_accuracy
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index 16ddc38f5a..e662b11be8 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -119,6 +119,7 @@ py_test(
deps = [
":pruning_utils",
"//tensorflow/python:client_testlib",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index cd58526ed3..a81abac2fa 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -476,8 +476,8 @@ class Pruning(object):
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
- updated_mask = pruning_utils.kronecker_product(
- new_mask, array_ops.ones(self._block_dim))
+
+ updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index ef6c6a3f5d..b50a372e9d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -69,7 +69,7 @@ def weight_threshold_variable(var, scope):
scope: The variable scope of the variable var
Returns:
- a scalar threshold variable initialized to 0.
+ A scalar threshold variable initialized to 0.
"""
with variable_scope.variable_scope(scope):
threshold = variable_scope.get_variable(
@@ -97,6 +97,74 @@ def kronecker_product(mat1, mat2):
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+def expand_tensor(tensor, block_dims):
+ """Expands a 2D tensor by replicating the tensor values.
+
+ This is equivalent to the kronecker product of the tensor and a matrix of
+ ones of size block_dims.
+
+ Example:
+
+ tensor = [[1,2]
+ [3,4]]
+ block_dims = [2,2]
+
+ result = [[1 1 2 2]
+ [1 1 2 2]
+ [3 3 4 4]
+ [3 3 4 4]]
+
+ Args:
+ tensor: A 2D tensor that needs to be expanded.
+ block_dims: List of integers specifying the expansion factor.
+
+ Returns:
+ The expanded tensor
+
+ Raises:
+ ValueError: if tensor is not rank-2 or block_dims is does not have 2
+ elements.
+ """
+ if tensor.get_shape().ndims != 2:
+ raise ValueError('Input tensor must be rank 2')
+
+ if len(block_dims) != 2:
+ raise ValueError('block_dims must have 2 elements')
+
+ block_height, block_width = block_dims
+
+ def _tile_rows(tensor, multiple):
+ """Create a new tensor by tiling the tensor along rows."""
+ return array_ops.tile(tensor, [multiple, 1])
+
+ def _generate_indices(num_rows, block_dim):
+ indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32)
+ for k in range(block_dim):
+ for r in range(num_rows):
+ indices[k * num_rows + r] = r * block_dim + k
+ return indices
+
+ def _replicate_rows(tensor, multiple):
+ tensor_shape = tensor.shape.as_list()
+ expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
+ indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
+ return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
+ expanded_shape)
+
+ expanded_tensor = tensor
+
+ # Expand rows by factor block_height.
+ if block_height > 1:
+ expanded_tensor = _replicate_rows(tensor, block_height)
+
+ # Transpose and expand by factor block_width. Transpose the result.
+ if block_width > 1:
+ expanded_tensor = array_ops.transpose(
+ _replicate_rows(array_ops.transpose(expanded_tensor), block_width))
+
+ return expanded_tensor
+
+
def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
"""Return histogram of values.
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index ccde5b4e8a..06d7f97437 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -43,20 +45,6 @@ class PruningUtilsTest(test.TestCase):
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
- def _compare_pooling_methods(self, weights, pooling_kwargs):
- with self.test_session():
- variables.global_variables_initializer().run()
- pooled_weights_tf = array_ops.squeeze(
- nn_ops.pool(
- array_ops.reshape(
- weights,
- [1, weights.get_shape()[0],
- weights.get_shape()[1], 1]), **pooling_kwargs))
- pooled_weights_factorized_pool = pruning_utils.factorized_pool(
- weights, **pooling_kwargs)
- self.assertAllClose(pooled_weights_tf.eval(),
- pooled_weights_factorized_pool.eval())
-
def testHistogram(self):
width = 10
height = 10
@@ -95,26 +83,60 @@ class PruningUtilsTest(test.TestCase):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
self._compare_cdf(weights)
- def testFactorizedAvgPool(self):
+
+@parameterized.named_parameters(
+ ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]),
+ ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1]))
+class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
+
+ def _compare_pooling_methods(self, weights, pooling_kwargs):
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ pooled_weights_tf = array_ops.squeeze(
+ nn_ops.pool(
+ array_ops.reshape(
+ weights,
+ [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]), **pooling_kwargs))
+ pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+ weights, **pooling_kwargs)
+ self.assertAllClose(pooled_weights_tf.eval(),
+ pooled_weights_factorized_pool.eval())
+
+ def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
+ with self.test_session() as session:
+ variables.global_variables_initializer().run()
+ expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
+ kronecker_product = pruning_utils.kronecker_product(
+ tensor, array_ops.ones(block_dim))
+ expanded_tensor_val, kronecker_product_val = session.run(
+ [expanded_tensor, kronecker_product])
+ self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
+
+ def testFactorizedAvgPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "AVG",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
- def testFactorizedMaxPool(self):
+ def testFactorizedMaxPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "MAX",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
+ def testExpandTensor(self, block_dim):
+ weights = random_ops.random_normal(shape=[1024, 512])
+ self._compare_expand_tensor_with_kronecker_product(weights, block_dim)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 57a96c5d33..09fad35d23 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -20,6 +20,13 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
+// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here,
+// incAtomic and other CUDA specific symbols are no longer recognized.
+#ifndef gpu_assert
+#define gpu_assert(x)
+#endif
+
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 778b710d78..5319a8b655 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -20,6 +20,7 @@ py_library(
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
"python/training/ggt.py",
+ "python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
@@ -365,3 +366,18 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+py_test(
+ name = "lars_optimizer_test",
+ srcs = ["python/training/lars_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 9471fb0181..781621dba0 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
+from tensorflow.contrib.opt.python.training.lars_optimizer import *
from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
@@ -46,6 +47,7 @@ _allowed_symbols = [
'DelayCompensatedGradientDescentOptimizer',
'DropStaleGradientOptimizer',
'ExternalOptimizerInterface',
+ 'LARSOptimizer',
'LazyAdamOptimizer',
'NadamOptimizer',
'MovingAverageOptimizer',
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 5763593b81..bbafd59aae 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -17,22 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-
-from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import constant_op
LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_STEP = 'global_step'
class ElasticAverageCustomGetter(object):
@@ -52,16 +53,32 @@ class ElasticAverageCustomGetter(object):
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
- ps_device="/job:ps/cpu:0",
+ ps_device="/job:ps",
cluster=cluster)),
tf.variable_scope('',custom_getter=ea_custom_getter):
- hid_w = tf.get_variable(
- initializer=tf.truncated_normal(
- [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
- stddev=1.0 / IMAGE_PIXELS),
- name="hid_w")
- hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
- name="hid_b")
+ ...
+ create your model here
+ ...
+ with tf.device(worker_device):
+ opt = tf.train.MomentumOptimizer(...)
+ optimizer = ElasticAverageOptimizer(
+ opt,
+ num_worker=2,
+ moving_rate=0.01, # or use default value
+ communication_period=20,
+ ea_custom_getter=ea_custom_getter)
+ ...
+ train_op = optimizer.apply_gradients(
+ grads_vars,
+ global_step=global_step)
+ ...
+ hooks = [optimizer.make_session_run_hook(is_chief, task_index)]
+ ...
+ with tf.train.MonitoredTrainingSession(master=server.target,
+ is_chief=is_chief,
+ checkpoint_dir=("...),
+ save_checkpoint_secs=600,
+ hooks=hooks) as mon_sess:
"""
def __init__(self, worker_device):
@@ -83,24 +100,40 @@ class ElasticAverageCustomGetter(object):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
*args,
**kwargs)
- global_center_variable = variable_scope.variable(
+ if kwargs['reuse'] == True:
+ return local_var
+ global_center_variable = getter(
name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
with ops.device(self._worker_device):
- local_center_variable = variable_scope.variable(
+ local_center_variable = getter(
name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
-
- self._local_map[local_var] = local_center_variable
- self._global_map[local_var] = global_center_variable
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._local_map[local_var] = local_center_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._local_map[v_list[i]] \
+ = list(local_center_variable)[i]
+ self._global_map[v_list[i]] \
+ = list(global_center_variable)[i]
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ return getter(
+ name,
+ trainable=trainable,
+ collections=collections,
+ *args,
+ **kwargs)
class ElasticAverageOptimizer(optimizer.Optimizer):
@@ -125,6 +158,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
moving_rate=None,
rho=None,
use_locking=True,
+ synchronous=False,
name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
@@ -136,9 +170,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
communication_period: An int point value to controls the frequency
of the communication between every worker and the ps.
moving_rate: A floating point value to control the elastic difference.
- rho: the amount of exploration we allow ine the model. The default
+ rho: the amount of exploration we allow in the model. The default
value is moving_rate/learning_rate
+ rho=0.0 is suggested in async mode.
use_locking: If True use locks for update operations.
+ synchronous: Add_sync_queues_and_barrier or not.
+ True: all workers will wait for each other before start training
+ False: worker can start training when its initilization is done,
+ no need to wait for everyone is ready.
+ in case one worker is restarted, it can join and continue
+ training without being blocked.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "ElasticAverageOptimizer".
"""
@@ -148,6 +189,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._period = communication_period
self._local_map = ea_custom_getter._local_map
self._global_map = ea_custom_getter._global_map
+ self._synchronous = synchronous
if moving_rate is None:
self._moving_rate = self.BETA / communication_period / num_worker
@@ -241,11 +283,29 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
+ global_old = set(n.op.name for n in variables.global_variables())
apply_updates = self._opt.apply_gradients(grads_and_vars)
+ global_new = set(n.op.name for n in variables.global_variables())
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
self._local_step, 1, name='local_step_update').op
+ # this is for place the variables created by optimizer to local collection
+ # e.g., AdamOptimizer will create beta as global variables
+ def _adjust_optimizer_variable_collection(opt_vars):
+ g = ops.get_default_graph()
+ idx = 0
+ for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
+ var = g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ name = var.op.name
+ if name in opt_vars:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
+ del g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ else:
+ idx += 1
+
+ _adjust_optimizer_variable_collection(global_new - global_old)
+
# update global variables.
def _Update_global_variables():
local_vars = [v for g, v in grads_and_vars if g is not None]
@@ -290,7 +350,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
variables equal to the global center variables before the training begins"""
def _Add_sync_queues_and_barrier(enqueue_after_list):
- """Adds ops to enqueu on all worker queues"""
+ """Adds ops to enqueue on all worker queues"""
sync_queues = [
data_flow_ops.FIFOQueue(
self._num_worker, [dtypes.bool],
@@ -324,6 +384,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
init_ops.append(state_ops.assign(lc_var, gc_var))
init_op = control_flow_ops.group(*(init_ops))
+ if self._synchronous == False:
+ return init_op
+
sync_queue_op = _Add_sync_queues_and_barrier([init_op])
return sync_queue_op
@@ -331,6 +394,51 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
"""Creates a hook to handle ElasticAverageOptimizerHook ops such as initialization."""
return _ElasticAverageOptimizerHook(self, is_chief, task_index)
+ def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
+ """Create a saver copy global_center_variable to trainable variables
+ Please call this function after all your variables created with
+ ElasticAverageCustomGetter. For evaluations or inference, use this saver
+ during training. It will save the global_center_variable of the trained
+ parameters under the original parameter names.
+ Args:
+ var_list: List of variables to save, as per `Saver()`.
+ If set to None, save all the trainable_variables that have
+ been created before this call.
+ name: The name of the saver.
+ **kwargs: Keyword arguments of `Saver()`.
+ Returns:
+ A `tf.train.Saver` object.
+ Raises:
+ RuntimeError: global_center_variable is empty, please make sure
+ this is called after model created and
+ ElasticAverageCustomGetter is used when declaring you model
+ """
+ if not self._global_map:
+ raise RuntimeError('global_center_variable is empty, please make sure '
+ 'this is called after model created and '
+ 'ElasticAverageCustomGetter is used when declaring '
+ 'you model')
+
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ if not isinstance(var_list, dict):
+ var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+ swapped_var_list = {}
+ for key, var in var_list.items():
+ tensor = var
+
+ if not isinstance(var, list):
+ for tvar in variables.trainable_variables():
+ if tvar.op.name == var.op.name:
+ tensor = self._global_map.get(tvar, var)
+ break
+ else: #partitioned variable
+ tensor = [self._global_map.get(lvar, lvar) for lvar in var]
+
+ swapped_var_list[key] = tensor
+
+ return saver.Saver(swapped_var_list, name=name, **kwargs)
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
@@ -351,3 +459,7 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
if self._is_chief:
self._global_init_op = variables.global_variables_initializer()
self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op) \ No newline at end of file
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 5ed8057b86..5bf6a08de1 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -17,17 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import portpicker
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import \
ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME
@@ -59,29 +64,49 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
# Creates the workers and return their sessions, graphs, train_ops.
# Chief worker will update at last
-def _get_workers(num_workers, period, workers, moving_rate):
+def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
sessions = []
graphs = []
train_ops = []
+ savers = []
for worker_id in range(num_workers):
graph = ops.Graph()
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
with variable_scope.variable_scope(
- "", custom_getter=ea_coustom), ops.device(
+ "", custom_getter=ea_custom), ops.device(
device_setter.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/task:0/cpu:0",
ps_tasks=1)):
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ global_step = training_util.get_or_create_global_step()
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ if num_ps > 1:
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0),
+ custom_getter=ea_custom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=num_ps)):
+
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(-1.0)
grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
opt = ElasticAverageOptimizer(
@@ -89,12 +114,22 @@ def _get_workers(num_workers, period, workers, moving_rate):
num_worker=num_workers,
moving_rate=moving_rate,
communication_period=period,
- ea_custom_getter=ea_coustom)
- train_op = [
- opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
- global_step)
- ]
+ ea_custom_getter=ea_custom)
+ if num_ps == 1:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
+ ]
+ else:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0],
+ [grads_1, var_1],
+ [grads_part_0, part_0],
+ [grads_part_1, part_1]),
+ global_step)
+ ]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
+ saver = opt.swapping_saver()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
workers[worker_id].target, hooks=[easgd_hook])
@@ -102,8 +137,9 @@ def _get_workers(num_workers, period, workers, moving_rate):
sessions.append(sess)
graphs.append(graph)
train_ops.append(train_op)
+ savers.append(saver)
- return sessions, graphs, train_ops
+ return sessions, graphs, train_ops, savers
class ElasticAverageOptimizerTest(test.TestCase):
@@ -118,7 +154,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
+ sessions, graphs, train_ops, savers = _get_workers(
num_workers, communication_period, workers, 1.0)
var_0 = graphs[0].get_tensor_by_name("v0:0")
@@ -158,6 +194,21 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(2.0, sessions[0].run(var_0_g))
self.assertAllEqual(3.0, sessions[0].run(var_1_g))
self.assertAllEqual(1, sessions[0].run(global_step))
+ sessions[0].run(train_ops[0])
+
+ # save, data will be global value
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+ ops.reset_default_graph() # restore on a new graph
+ with session.Session() as sess:
+ v0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ v1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ sess.run(variables.local_variables_initializer())
+ saver_opt = saver.Saver(var_list=[v1, v0])
+ saver_opt.restore(sess, outfile)
+ self.assertAllEqual(2.0, sess.run(v0))
+ self.assertAllEqual(3.0, sess.run(v1))
def test2Worker1Period(self):
num_workers = 2
@@ -166,8 +217,8 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
- num_workers, communication_period, workers, 0.5)
+ sessions, graphs, train_ops, savers = _get_workers(
+ num_workers, communication_period, workers, 0.5, num_ps=2)
var_0 = graphs[0].get_tensor_by_name("v0:0")
var_1 = graphs[0].get_tensor_by_name("v1:0")
@@ -177,6 +228,9 @@ class ElasticAverageOptimizerTest(test.TestCase):
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(
+ GLOBAL_VARIABLE_NAME + "/partition_var/part_0:0")
+
# Verify the initialized value.
self.assertAllEqual(0.0, sessions[0].run(var_0))
self.assertAllEqual(1.0, sessions[0].run(var_1))
@@ -194,22 +248,45 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(1.75, sessions[0].run(var_1_g))
self.assertAllEqual(0.75, sessions[1].run(var_0_1))
self.assertAllEqual(1.75, sessions[1].run(var_1_1))
+ # part_0 of global_center copy
+ part_0_g = sessions[0].run(part_0_g)
+
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+
+ # verify restore of partitioned_variables
+ ops.reset_default_graph() # restore on a new graph
+ g = ops.get_default_graph()
+ with session.Session() as sess, g.as_default():
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0)):
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ s = saver.Saver(var_list=[partition_var])
+ s.restore(sess, outfile)
+ part_0 = g.get_tensor_by_name('partition_var/part_0:0')
+ self.assertAllEqual(part_0_g, sess.run(part_0))
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
+ ea_custom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope("", custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_custom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
w = variable_scope.get_variable(initializer=[2, 1], name="w")
- v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
+ v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py
new file mode 100644
index 0000000000..a8dafd9a4c
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py
@@ -0,0 +1,164 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class LARSOptimizer(optimizer.Optimizer):
+ """Layer-wise Adaptive Rate Scaling for large batch training.
+
+ Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
+ I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
+
+ Implements the LARS learning rate scheme presented in the paper above. This
+ optimizer is useful when scaling the batch size to up to 32K without
+ significant performance degradation. It is recommended to use the optimizer
+ in conjunction with:
+ - Gradual learning rate warm-up
+ - Linear learning rate scaling
+ - Poly rule learning rate decay
+
+ Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
+ use the default momentum optimizer.
+ """
+
+ def __init__(
+ self,
+ learning_rate,
+ momentum=0.9,
+ weight_decay=0.0001,
+ # The LARS coefficient is a hyperparameter
+ eeta=0.001,
+ epsilon=0.0,
+ name="LARSOptimizer",
+ # Enable skipping variables from LARS scaling.
+ # TODO(sameerkm): Enable a direct mechanism to pass a
+ # subset of variables to the optimizer.
+ skip_list=None,
+ use_nesterov=False):
+ """Construct a new LARS Optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or floating point value. The base learning rate.
+ momentum: A floating point value. Momentum hyperparameter.
+ weight_decay: A floating point value. Weight decay hyperparameter.
+ eeta: LARS coefficient as used in the paper. Dfault set to LARS
+ coefficient from the paper. (eeta / weight_decay) determines the highest
+ scaling factor in LARS.
+ epsilon: Optional epsilon parameter to be set in models that have very
+ small gradients. Default set to 0.0.
+ name: Optional name prefix for variables and ops created by LARSOptimizer.
+ skip_list: List of strings to enable skipping variables from LARS scaling.
+ If any of the strings in skip_list is a subset of var.name, variable
+ 'var' is skipped from LARS scaling. For a typical classification model
+ with batch normalization, the skip_list is ['batch_normalization',
+ 'bias']
+ use_nesterov: when set to True, nesterov momentum will be enabled
+
+ Raises:
+ ValueError: If a hyperparameter is set to a non-sensical value.
+ """
+ if momentum < 0.0:
+ raise ValueError("momentum should be positive: %s" % momentum)
+ if weight_decay < 0.0:
+ raise ValueError("weight_decay should be positive: %s" % weight_decay)
+ super(LARSOptimizer, self).__init__(use_locking=False, name=name)
+
+ self._learning_rate = learning_rate
+ self._momentum = momentum
+ self._weight_decay = weight_decay
+ self._eeta = eeta
+ self._epsilon = epsilon
+ self._name = name
+ self._skip_list = skip_list
+ self._use_nesterov = use_nesterov
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._zeros_slot(v, "momentum", self._name)
+
+ def compute_lr(self, grad, var):
+ scaled_lr = self._learning_rate
+ if self._skip_list is None or not any(v in var.name
+ for v in self._skip_list):
+ w_norm = linalg_ops.norm(var, ord=2)
+ g_norm = linalg_ops.norm(grad, ord=2)
+ trust_ratio = array_ops.where(
+ math_ops.greater(w_norm, 0),
+ array_ops.where(
+ math_ops.greater(g_norm, 0),
+ (self._eeta * w_norm /
+ (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
+ 1.0)
+ scaled_lr = self._learning_rate * trust_ratio
+ return scaled_lr
+
+ def _apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ def _resource_apply_dense(self, grad, var):
+ scaled_lr = self.compute_lr(grad, var)
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ scaled_lr,
+ grad,
+ self._momentum,
+ use_locking=False,
+ use_nesterov=self._use_nesterov)
+
+ # Fallback to momentum optimizer for sparse tensors
+ def _apply_sparse(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ math_ops.cast(self._learning_rate_tensor, grad.dtype),
+ grad,
+ indices,
+ math_ops.cast(self._momentum_tensor, grad.dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
new file mode 100644
index 0000000000..d94249b994
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
@@ -0,0 +1,127 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0. Licensed to the Apache
+# Software Foundation. You may not use this file except in compliance with the
+# License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for Layer-wise Adaptive Rate Scaling optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import lars_optimizer as lo
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class LARSOptimizerTest(test.TestCase):
+
+ def testLARSGradientOneStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.test_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ weight_decay=wd_np,
+ eeta=eeta,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ step.run()
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+ def testLARSGradientMultiStep(self):
+ for _ in range(10):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.test_session() as sess:
+ shape = [3, 3]
+ var_np = np.ones(shape)
+ grad_np = np.ones(shape)
+ lr_np = 0.1
+ m_np = 0.9
+ wd_np = 0.1
+ ep_np = 1e-5
+ eeta = 0.1
+ vel_np = np.zeros(shape)
+
+ var = variables.Variable(var_np, dtype=dtype)
+ grad = variables.Variable(grad_np, dtype=dtype)
+ opt = lo.LARSOptimizer(
+ learning_rate=lr_np,
+ momentum=m_np,
+ eeta=eeta,
+ weight_decay=wd_np,
+ epsilon=ep_np)
+
+ step = opt.apply_gradients([(grad, var)])
+ variables.global_variables_initializer().run()
+
+ pre_var = sess.run(var)
+ pre_vel = sess.run(opt.get_slot(var, 'momentum'))
+ self.assertAllClose(var_np, pre_var)
+ self.assertAllClose(vel_np, pre_vel)
+
+ for _ in range(10):
+ step.run()
+
+ post_var = sess.run(var)
+ post_vel = sess.run(opt.get_slot(var, 'momentum'))
+
+ w_norm = np.linalg.norm(var_np.flatten(), ord=2)
+ g_norm = np.linalg.norm(grad_np.flatten(), ord=2)
+ trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np)
+ scaled_lr = lr_np * trust_ratio
+
+ vel_np = m_np * vel_np + grad_np
+ var_np -= scaled_lr * vel_np
+
+ self.assertAllClose(var_np, post_var)
+ self.assertAllClose(vel_np, post_vel)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 23363617ed..499fec4ffa 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -244,7 +244,9 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
],
)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 2944f964c7..484493f1b2 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -59,6 +59,10 @@ def _create_graph(input_graph=None,
if input_graph is None:
input_graph = ops.get_default_graph()
+
+ # Add check to see if graph has training ops, if so provide error message and
+ # exit
+ _check_for_training_ops(input_graph)
with input_graph.as_default():
fold_batch_norms.FoldBatchNorms(
input_graph,
@@ -78,6 +82,9 @@ def create_training_graph(input_graph=None, quant_delay=0):
Variables added by the rewrite get added to the global variables collection.
+ This function must be invoked prior to insertion of gradient ops in a graph
+ as quantization should be modeled in both forward and backward passes.
+
The graph has fake quantization ops inserted to simulate the error
introduced by quantization. Since the graph is transformed in place,
the expected behavior of previously held references to nodes and tensors may
@@ -104,7 +111,6 @@ def create_training_graph(input_graph=None, quant_delay=0):
# Currently the values below are hardcoded for mobilenetV1 on imagenet
# Please use the experimental API if you need to tune these values.
freeze_bn_delay = None
-
_create_graph(
input_graph=input_graph,
is_training=True,
@@ -141,6 +147,9 @@ def experimental_create_training_graph(input_graph=None,
scope=None):
"""Rewrites a training input_graph in place for simulated quantization.
+ This function must be invoked prior to insertion of gradient ops in a graph
+ as quantization should be modeled in both forward and backward passes.
+
Variables added by the rewrite get added to the global variables collection.
This function has additional experimental options not (yet) available to
@@ -226,3 +235,45 @@ def experimental_create_eval_graph(input_graph=None,
activation_bits=activation_bits,
quant_delay=quant_delay,
scope=scope)
+
+
+def _check_for_training_ops(g):
+ """Check if training ops are present in the graph.
+
+ Args:
+ g: The tf.Graph on which the check for training ops needs to be
+ performed.
+
+ Raises:
+ ValueError: If a training op is seen in the graph;
+ """
+
+ # The list here is obtained
+ # from https://www.tensorflow.org/api_docs/cc/group/training-ops
+ training_ops = frozenset([
+ 'ApplyAdagrad', 'ApplyAdagradDA', 'ApplyAdam', 'ApplyAddSign',
+ 'ApplyCenteredRMSProp', 'ApplyFtrl', 'ApplyFtrlV2',
+ 'ApplyGradientDescent', 'ApplyMomentum', 'ApplyPowerSign',
+ 'ApplyProximalAdagrad', 'ApplyProximalGradientDescent', 'ApplyRMSProp',
+ 'ResourceApplyAdadelta', 'ResourceApplyAdagrad', 'ResourceApplyAdagradDA',
+ 'ResourceApplyAdam', 'ResourceApplyAddSign',
+ 'ResourceApplyCenteredRMSProp', 'ResourceApplyFtrl',
+ 'ResourceApplyFtrlV2', 'ResourceApplyGradientDescent',
+ 'ResourceApplyMomentum', 'ResourceApplyPowerSign',
+ 'ResourceApplyProximalAdagrad', 'ResourceApplyProximalGradientDescent',
+ 'ResourceApplyRMSProp', 'ResourceSparseApplyAdadelta',
+ 'ResourceSparseApplyAdagrad', 'ResourceSparseApplyAdagradDA',
+ 'ResourceSparseApplyCenteredRMSProp', 'ResourceSparseApplyFtrl',
+ 'ResourceSparseApplyFtrlV2', 'ResourceSparseApplyMomentum',
+ 'ResourceSparseApplyProximalAdagrad',
+ 'ResourceSparseApplyProximalGradientDescent',
+ 'ResourceSparseApplyRMSProp', 'SparseApplyAdadelta', 'SparseApplyAdagrad',
+ 'SparseApplyAdagradDA', 'SparseApplyCenteredRMSProp', 'SparseApplyFtrl',
+ 'SparseApplyFtrlV2', 'SparseApplyMomentum', 'SparseApplyProximalAdagrad',
+ 'SparseApplyProximalGradientDescent', 'SparseApplyRMSProp'
+ ])
+
+ op_types = set([op.type for op in g.get_operations()])
+ train_op_list = op_types.intersection(training_ops)
+ if train_op_list:
+ raise ValueError('Training op found in graph, exiting %s' % train_op_list)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index 54faf582f1..e80d2183a6 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -20,10 +20,12 @@ from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
+from tensorflow.python import training
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -145,6 +147,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
self.assertTrue(quant_delay_found)
+ def testTrainingOpsCheck(self):
+ self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck)
+
+ def _TestTrainingOpsCheck(self, rewrite_fn):
+ with ops.Graph().as_default():
+ output = self._ConvLayer()
+ output_scalar = math_ops.reduce_sum(output)
+ loss = math_ops.square(output_scalar - 1)
+ opt = training.gradient_descent.GradientDescentOptimizer(0.0001)
+ opt.minimize(loss)
+ with self.assertRaisesRegexp(ValueError, 'Training op found in graph'):
+ rewrite_fn()
+
def testWeightBits(self):
self._RunTestOverExperimentalRewrites(self._TestWeightBits)
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index cb437f2a2f..026bf08ced 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""RNN Cells and additional RNN operations.
-See @{$python/contrib.rnn} guide.
+See [Contrib RNN](https://tensorflow.org/api_guides/python/contrib.rnn) guide.
<!--From core-->
@@RNNCell
diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py
index a7279bc339..674f7cdb22 100644
--- a/tensorflow/contrib/seq2seq/__init__.py
+++ b/tensorflow/contrib/seq2seq/__init__.py
@@ -15,7 +15,9 @@
"""Ops for building neural network seq2seq decoders and losses.
-See the @{$python/contrib.seq2seq} guide.
+See the
+[Contrib Seq2seq](https://tensorflow.org/api_guides/python/contrib.seq2seq)
+guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 6a2080bcec..d088e74434 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Signal processing operations.
-See the @{$python/contrib.signal} guide.
+See the
+[Contrib Signal](https://tensorflow.org/api_guides/python/contrib.signal)
+guide.
@@frame
@@hamming_window
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 164f3e58e6..22d6e499d2 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -515,6 +515,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
+ "//tensorflow/contrib/estimator:head",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 8fa0b3ada9..db970deff5 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
+from tensorflow.contrib.estimator.python.estimator import head as core_head_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -25,7 +26,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
@@ -130,17 +130,23 @@ def _get_default_head(params, weights_name, output_type, name=None):
head_name=name)
else:
if params.regression:
- return core_head_lib._regression_head( # pylint:disable=protected-access
+ return core_head_lib.regression_head(
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
else:
- return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
- n_classes=params.num_classes,
- weight_column=weights_name,
- name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ if params.num_classes == 2:
+ return core_head_lib.binary_classification_head(
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ else:
+ return core_head_lib.multi_class_head(
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
index 6cb2c881e2..7716536ba4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
@@ -54,17 +54,24 @@ InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
<< "Invalid feature ID: [" << test.feature_id().id().value() << "]";
threshold_ = test.threshold().float_value();
- include_equals_ =
- test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL;
+ _test_type = test.type();
}
int32 InequalityDecisionNodeEvaluator::Decide(
const std::unique_ptr<TensorDataSet>& dataset, int example) const {
const float val = dataset->GetExampleValue(example, feature_num_);
- if (val < threshold_ || (include_equals_ && val == threshold_)) {
- return left_child_id_;
- } else {
- return right_child_id_;
+ switch (_test_type) {
+ case decision_trees::InequalityTest::LESS_OR_EQUAL:
+ return val <= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::LESS_THAN:
+ return val < threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_OR_EQUAL:
+ return val >= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_THAN:
+ return val > threshold_ ? left_child_id_ : right_child_id_;
+ default:
+ LOG(ERROR) << "Unknown split test type: " << _test_type;
+ return -1;
}
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
index 3db351c328..6497787f84 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
@@ -55,9 +55,7 @@ class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
protected:
int32 feature_num_;
float threshold_;
-
- // If decision is '<=' as opposed to '<'.
- bool include_equals_;
+ ::tensorflow::decision_trees::InequalityTest_Type _test_type;
};
// Evaluator for splits with multiple weighted features.
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
index af5cf72a3c..3db1335563 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
@@ -60,6 +60,40 @@ TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) {
ASSERT_EQ(eval->Decide(dataset, 4), 1);
}
+TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_OR_EQUAL);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 0);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
+TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_THAN);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 1);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
TEST(MatchingDecisionNodeEvaluatorTest, Basic) {
MatchingValuesTest test;
test.mutable_feature_id()->mutable_id()->set_value("0");
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index fc0d22d112..26236a0435 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -387,17 +387,19 @@ cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
"test/base_test.py",
- # "test/batch_matmul_test.py",
- # "test/biasadd_matmul_test.py",
- # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation
- # "test/concatenation_test.py", # Blocked by trt4 installation
+ "test/batch_matmul_test.py",
+ "test/biasadd_matmul_test.py",
+ "test/binary_tensor_weight_broadcast_test.py",
+ "test/concatenation_test.py",
"test/const_broadcast_test.py",
+ "test/manual_test.py",
+ "test/memory_alignment_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
- # "test/unary_test.py", # Blocked by trt4 installation
- # "test/vgg_block_nchw_test.py",
- # "test/vgg_block_test.py",
- "test/memory_alignment_test.py",
+ "test/rank_two_test.py",
+ "test/unary_test.py",
+ "test/vgg_block_nchw_test.py",
+ "test/vgg_block_test.py",
],
additional_deps = [
":tf_trt_integration_test_base",
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 35fa590254..863074e773 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -155,12 +155,22 @@ tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
for (int d = 1; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return tensorflow::errors::InvalidArgument(
- "Input tensor has a unknown non-batch dimemension at dim ", d);
+ "Input tensor with shape ", shape.DebugString(),
+ " has an unknown non-batch dimemension at dim ", d);
}
}
return Status::OK();
}
+string DebugString(const nvinfer1::Dims& dims) {
+ string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
+ for (int i = 0; i < nvinfer1::Dims::MAX_DIMS; ++i) {
+ StrAppend(&out, dims.d[i], ",");
+ }
+ StrAppend(&out, ")");
+ return out;
+}
+
// Return whether or not the broadcast is feasible;
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
const bool operand_l_is_tensor,
@@ -353,6 +363,13 @@ class TRT_ShapedWeights {
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+ string DebugString() const {
+ return StrCat(
+ "TRT_ShapedWeights(shape=", convert::DebugString(shape_), ", type=",
+ type_, ", values=", reinterpret_cast<uintptr_t>(values_),
+ ", empty_weight_flag=", empty_weight_flag_, ")");
+ }
+
// TODO(aaroey): make these private.
nvinfer1::Dims shape_;
tensorflow::DataType type_;
@@ -367,11 +384,14 @@ class TRT_TensorOrWeights {
public:
explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
: tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
+
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
: tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+
// TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
: tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
+
~TRT_TensorOrWeights() {}
bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
@@ -381,18 +401,22 @@ class TRT_TensorOrWeights {
CHECK(is_tensor());
return tensor_;
}
+
const nvinfer1::ITensor* tensor() const {
CHECK(is_tensor());
return tensor_;
}
+
TRT_ShapedWeights& weights() {
CHECK(is_weights());
return weights_;
}
+
const TRT_ShapedWeights& weights() const {
CHECK(is_weights());
return weights_;
}
+
nvinfer1::Dims shape() const {
if (is_tensor()) {
return tensor()->getDimensions();
@@ -401,6 +425,18 @@ class TRT_TensorOrWeights {
}
}
+ string DebugString() const {
+ string output = "TRT_TensorOrWeights(type=";
+ if (is_tensor()) {
+ StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor_),
+ ", shape=", convert::DebugString(tensor_->getDimensions()));
+ } else {
+ StrAppend(&output, "weights=", weights_.DebugString());
+ }
+ StrAppend(&output, ")");
+ return output;
+ }
+
private:
nvinfer1::ITensor* tensor_;
TRT_ShapedWeights weights_;
@@ -555,7 +591,7 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
}
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
- TRT_ShapedWeights* oweights, int num_groups) {
+ TRT_ShapedWeights* oweights, const int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
// K indexes over output channels, C over input channels, and R and S over the
@@ -563,13 +599,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
const int r = iweights.shape_.d[0];
const int s = iweights.shape_.d[1];
// TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
- VLOG(2) << "num_groups: " << num_groups;
const int c = iweights.shape_.d[2] / num_groups;
- VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
const int k = iweights.shape_.d[3] * num_groups;
- VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
- VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r;
- VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s;
+ VLOG(2) << "num_groups: " << num_groups
+ << "c" << iweights.shape_.d[2] << " then " << c
+ << "k" << iweights.shape_.d[3] << " then " << k
+ << "r" << iweights.shape_.d[0] << " then " << r
+ << "s" << iweights.shape_.d[1] << " then " << s;
oweights->shape_.d[0] = k / num_groups;
oweights->shape_.d[1] = c * num_groups;
oweights->shape_.d[2] = r;
@@ -607,63 +643,15 @@ using OpConverter =
std::vector<TRT_TensorOrWeights>*)>;
class Converter {
- // TODO(aaroey): fix the order of members.
- std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
- std::unordered_map<string, OpConverter> op_registry_;
- OpConverter plugin_converter_;
- nvinfer1::INetworkDefinition* trt_network_;
- std::list<std::vector<uint8_t>> temp_bufs_;
- // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
- // operate the stored weights instead of operating it directly.
- TRTWeightStore* weight_store_;
- bool fp16_;
- void register_op_converters();
- tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
- std::vector<TRT_TensorOrWeights>* inputs) {
- for (auto const& input_name : node_def.input()) {
- /*************************************************************************
- * TODO(jie): handle case 1) here.
- * Normalizes the inputs and extracts associated metadata:
- * 1) Inputs can contain a colon followed by a suffix of characters.
- * That suffix may be a single number (e.g. inputName:1) or several
- * word characters separated from a number by a colon
- * (e.g. inputName:foo:1). The
- * latter case is used to denote inputs and outputs of functions.
- * 2) Control dependency inputs contain caret at the beginning and we
- * remove this and annotate the edge as a control dependency.
- ************************************************************************/
- // skip control nodes
- if (input_name[0] == '^') continue;
- string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0')
- name.erase(first);
-
- VLOG(2) << "retrieve input: " << name;
- if (trt_tensors_.count(name)) {
- inputs->push_back(trt_tensors_.at(name));
- } else {
- // TODO(aaroey): this should not happen, make it a CHECK.
- // TODO(aaroey): use StrCat for pattern like this.
- string msg("Node ");
- StrAppend(&msg, node_def.name(), " should have an input named '", name,
- "' but it is not available");
- LOG(ERROR) << msg;
- return tensorflow::errors::InvalidArgument(msg);
- }
- }
- return tensorflow::Status::OK();
- }
-
public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters();
}
+
TRTWeightStore* weight_store() { return weight_store_; }
+
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape);
@@ -672,8 +660,10 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
+
// TODO(aaroey): fix all the namings.
bool isFP16() { return fp16_; }
+
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
}
@@ -684,7 +674,6 @@ class Converter {
const string& op = node_def.op();
std::vector<TRT_TensorOrWeights> outputs;
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
- // TODO(aaroey): plugin_converter_ is not set, fix it.
TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
} else {
if (!op_registry_.count(op)) {
@@ -702,7 +691,8 @@ class Converter {
if (output.is_tensor()) {
output.tensor()->setName(output_name.c_str());
}
- VLOG(2) << "Write out tensor: " << output_name;
+ VLOG(2) << "Adding out tensor " << output_name << ": "
+ << output.DebugString();
if (!trt_tensors_.insert({output_name, output}).second) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + op);
@@ -751,6 +741,63 @@ class Converter {
layer->setReshapeDimensions(reshape_dims);
return layer->getOutput(0);
}
+
+ private:
+ std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
+ std::unordered_map<string, OpConverter> op_registry_;
+ OpConverter plugin_converter_;
+ nvinfer1::INetworkDefinition* trt_network_;
+ std::list<std::vector<uint8_t>> temp_bufs_;
+
+ // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
+ // operate the stored weights instead of operating it directly.
+ TRTWeightStore* weight_store_;
+
+ bool fp16_;
+
+ void register_op_converters();
+
+ tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights>* inputs) {
+ for (auto const& input_name : node_def.input()) {
+ /*************************************************************************
+ * TODO(jie): handle case 1) here.
+ * Normalizes the inputs and extracts associated metadata:
+ * 1) Inputs can contain a colon followed by a suffix of characters.
+ * That suffix may be a single number (e.g. inputName:1) or several
+ * word characters separated from a number by a colon
+ * (e.g. inputName:foo:1). The
+ * latter case is used to denote inputs and outputs of functions.
+ * 2) Control dependency inputs contain caret at the beginning and we
+ * remove this and annotate the edge as a control dependency.
+ ************************************************************************/
+ // skip control nodes
+ if (input_name[0] == '^') continue;
+ string name = input_name;
+ auto first = name.find_first_of(':');
+ // TODO(aaroey): why removing the colon but not the zero? A bug?
+ // TODO(aaroey): use TensorId
+ if (first != string::npos && first + 2 == name.size() &&
+ name[first + 1] == '0') {
+ name.erase(first);
+ }
+
+ if (trt_tensors_.count(name)) {
+ TRT_TensorOrWeights& input = trt_tensors_.at(name);
+ inputs->push_back(input);
+ VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
+ } else {
+ // TODO(aaroey): this should not happen, make it a CHECK.
+ // TODO(aaroey): use StrCat for pattern like this.
+ string msg("Node ");
+ StrAppend(&msg, node_def.name(), " should have an input named '", name,
+ "' but it is not available");
+ LOG(ERROR) << msg;
+ return tensorflow::errors::InvalidArgument(msg);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
};
TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
@@ -1187,17 +1234,11 @@ tensorflow::Status ConvertConv2DHelper(
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
-
- VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims;
- for (int i = 0; i < weights_rsck.shape_.nbDims; i++) {
- VLOG(2) << weights_rsck.shape_.d[i];
- }
-
+ VLOG(2) << "weight shape: " << weights_rsck.DebugString();
if (weights_rsck.shape_.nbDims != 4) {
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
-
if (ctx.isFP16()) {
weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
@@ -1209,16 +1250,13 @@ tensorflow::Status ConvertConv2DHelper(
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
- VLOG(2) << "RSCK: ";
- for (int i = 0; i < 4; i++) {
- VLOG(2) << " " << weights.shape_.d[i];
- }
+ VLOG(2) << "RSCK: " << weights.DebugString();
VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
// TODO(jie): stride. (NHWC/NCHW)
const auto tf_stride = attrs.get<std::vector<int>>("strides");
VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
- VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
+ VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
<< tf_stride[3];
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
@@ -1240,10 +1278,7 @@ tensorflow::Status ConvertConv2DHelper(
// TODO(jie): handle asymmetric padding
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
-
- auto dim_before = tensor->getDimensions();
- VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
- << dim_before.d[2] << ", " << dim_before.d[3];
+ VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions());
auto pad_layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -1251,9 +1286,7 @@ tensorflow::Status ConvertConv2DHelper(
TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
- auto dim_after = tensor->getDimensions();
- VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
- << dim_after.d[2] << ", " << dim_after.d[3];
+ VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions());
}
nvinfer1::IConvolutionLayer* layer =
@@ -1266,17 +1299,12 @@ tensorflow::Status ConvertConv2DHelper(
layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-
- auto dim_after = output_tensor->getDimensions();
- VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
- << dim_after.d[2] << ", " << dim_after.d[3];
-
+ VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
+ VLOG(2) << "data_format: " << data_format;
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
- } else {
- VLOG(2) << "NCHW !!!!";
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1990,22 +2018,22 @@ tensorflow::Status ConvertReduce(Converter& ctx,
return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
}
- const auto keep_dims = attrs.get<bool>("keep_dims");
- auto index_list_data =
- static_cast<int*>(const_cast<void*>(index_list.GetValues()));
-
int axes = 0;
if (index_list.count() == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot support reduce on all (batch) dimensions, at",
node_def.name());
} else {
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
for (int i = 0; i < index_list.count(); i++) {
- if (index_list_data[i] == 0) {
+ int axis = index_list_data[i];
+ if (axis < 0) axis += tensor->getDimensions().nbDims + 1;
+ if (axis == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot reduce at batch dimension, at", node_def.name());
}
- axes |= (1 << (index_list_data[i] - 1));
+ axes |= (1 << (axis - 1));
}
}
@@ -2025,6 +2053,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
" , at ", node_def.name());
}
+ const auto keep_dims = attrs.get<bool>("keep_dims");
nvinfer1::ILayer* layer =
ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor),
reduce_operation, axes, keep_dims);
@@ -2694,8 +2723,6 @@ tensorflow::Status ConvertGraphDefToEngine(
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
- nvinfer1::DimsCHW input_dim_pseudo_chw;
- for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
@@ -2713,28 +2740,25 @@ tensorflow::Status ConvertGraphDefToEngine(
LOG(WARNING) << error_message;
return Status(status.code(), error_message);
}
- if (VLOG_IS_ON(1)) {
- string dim_str("dims=");
- StrAppend(&dim_str, "[ ", shape.dim_size(0));
- for (int i = 1; i < shape.dims(); i++) {
- StrAppend(&dim_str, ", ", shape.dim_size(i));
- }
- StrAppend(&dim_str, " ]");
- VLOG(1) << dim_str;
- }
+
+#if NV_TENSORRT_MAJOR == 3
+ nvinfer1::DimsCHW input_dim;
+#elif NV_TENSORRT_MAJOR > 3
+ nvinfer1::Dims input_dim;
+#endif
for (int i = 1; i < shape.dims(); i++) {
- input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
+ input_dim.d[i - 1] = shape.dim_size(i);
}
-
- input_dim_pseudo_chw.nbDims = shape.dims() - 1;
- nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- node_name.c_str(), dtype, input_dim_pseudo_chw);
+ input_dim.nbDims = shape.dims() - 1;
+ nvinfer1::ITensor* input_tensor =
+ converter.network()->addInput(node_name.c_str(), dtype, input_dim);
if (!input_tensor) {
return tensorflow::errors::InvalidArgument(
"Failed to create Input layer tensor ", node_name,
" rank=", shape.dims() - 1);
}
- VLOG(1) << "Input tensor name :" << node_name;
+ VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
+ << DebugString(input_dim);
if (!converter.insert_input_tensor(node_name, input_tensor)) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + node_name);
@@ -2937,10 +2961,25 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
<< ": " << status;
return false;
}
- if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+
+
+ if (in_edge->src()->type_string() != "Const" &&
+#if NV_TENSORRT_MAJOR == 3
+ // TRT 3.x only support 4 dimensional input tensor.
+ shape.dims() != 4) {
+#else
+ // Single dimensional input tensor is not supported since the first
+ // dimension is treated as batch dimension.
+ shape.dims() < 2) {
+#endif
VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
- << " which has an input at port " << in_edge->dst_input()
- << " with #dim<3 and is not a const: " << shape;
+ << " which has an input at port " << in_edge->dst_input() << " with"
+#if NV_TENSORRT_MAJOR == 3
+ << " #dim!=4"
+#else
+ << " #dim<2"
+#endif
+ << " and is not a const: " << shape;
return false;
}
return true;
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
index 2de7973750..11335d7da6 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
#include <vector>
+#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op_kernel.h"
-
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
#include "tensorflow/core/platform/stream_executor.h"
@@ -80,5 +81,5 @@ REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT);
} // namespace tensorrt
} // namespace tensorflow
-#endif // GOOGLE_CUDA
#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index b43f1b190f..c82d4a0183 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -74,6 +74,7 @@ class SimpleNode {
const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
+
std::vector<SimpleNode*> in_nodes() const {
std::vector<SimpleNode*> res;
res.reserve(in_edges_.size());
@@ -82,6 +83,16 @@ class SimpleNode {
}
return res;
}
+
+ std::vector<SimpleNode*> out_nodes() const {
+ std::vector<SimpleNode*> res;
+ res.reserve(out_edges_.size());
+ for (const auto e : out_edges_) {
+ if (e) res.push_back(e->dst());
+ }
+ return res;
+ }
+
const string& name() const { return node_->name(); }
const tensorflow::Node* tf_node() const { return node_; }
int id() const { return id_; }
@@ -215,45 +226,53 @@ SimpleGraph::~SimpleGraph() {
namespace {
-bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
- const std::vector<SimpleNode*>& start) {
- // Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+// Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+void StableDFS(const SimpleGraph& g, bool reverse,
+ const std::vector<const SimpleNode*>& start,
+ const std::function<bool(const SimpleNode*)>& enter,
+ const std::function<bool(const SimpleNode*)>& leave) {
+ // Stack of work to do.
struct Work {
- SimpleNode* node;
+ const SimpleNode* node;
bool leave; // Are we entering or leaving n?
};
-
std::vector<Work> stack(start.size());
for (int i = 0; i < start.size(); ++i) {
stack[i] = Work{start[i], false};
}
- std::vector<bool> visited(g->num_node_ids(), false);
+ auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
+ : [](const SimpleNode* n) { return n->out_nodes(); };
+ std::vector<bool> visited(g.num_node_ids(), false);
while (!stack.empty()) {
Work w = stack.back();
stack.pop_back();
auto n = w.node;
if (w.leave) {
- if (n == src) {
- return true;
- }
+ if (leave && !leave(n)) return;
continue;
}
if (visited[n->id()]) continue;
visited[n->id()] = true;
- // Arrange to call leave(n) when all done with descendants.
- stack.push_back(Work{n, true});
+ if (enter && !enter(n)) return;
- auto nodes = n->in_nodes();
- for (const auto node : nodes) {
+ // Arrange to call leave(n) when all done with descendants.
+ if (leave) stack.push_back(Work{n, true});
+
+ auto nodes = get_nodes(n);
+ std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
+ std::sort(nodes_sorted.begin(), nodes_sorted.end(),
+ [](const SimpleNode* lhs, const SimpleNode* rhs) {
+ return lhs->name() < rhs->name();
+ });
+ for (const SimpleNode* node : nodes_sorted) {
if (!visited[node->id()]) {
stack.push_back(Work{node, false});
}
}
}
- return false;
}
bool CanContractEdge(const SimpleEdge* edge,
@@ -289,14 +308,21 @@ bool CanContractEdge(const SimpleEdge* edge,
// To achieve this goal, the correct way seems to be:
// 1. remove any direct edge from src->dst;
// 2. detect if src can reach dst, if so they cannot be merged.
- std::vector<SimpleNode*> dfs_start_nodes;
- for (SimpleNode* node : dst->in_nodes()) {
+ std::vector<const SimpleNode*> dfs_start_nodes;
+ for (const SimpleNode* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
-
- const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes);
+ bool has_cycle = false;
+ StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
+ [&has_cycle, src](const SimpleNode* n) {
+ if (n == src) {
+ has_cycle = true;
+ return false;
+ }
+ return true;
+ });
return !has_cycle;
}
} // namespace
@@ -403,15 +429,13 @@ tensorflow::Status SegmentGraph(
// In the future if we have a measure of how beneficial it is to include a
// given node in a TRT subgraph then we can revisit this algorithm to take
// advantage of that information.
- std::vector<tensorflow::Node*> tforder;
- tensorflow::GetPostOrder(*tf_graph, &tforder);
- // use postorder implementation from tensorflow and construct mirror in
- // internal format
- std::vector<SimpleNode*> order;
- order.reserve(tforder.size());
- for (const auto tfnode : tforder) {
- order.push_back(graph->FindNodeId(tfnode->id()));
- }
+ std::vector<const SimpleNode*> order;
+ order.reserve(graph->num_node_ids());
+ StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
+ /*enter=*/nullptr, [&order](const SimpleNode* n) {
+ order.push_back(n);
+ return true;
+ });
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 8ea5a63735..e9ac833d55 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -40,6 +40,7 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -62,19 +63,21 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
identity = array_ops.identity(relu, "identity")
pool = nn_ops.max_pool(
identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(pool, name=self.output_name)
+ array_ops.squeeze(pool, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
- # "relu", "identity", "max_pool"]
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(100, 6, 6, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 6, 6, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ return ["my_trt_op_0"]
class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -85,6 +88,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -115,20 +119,22 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
q = math_ops.mul(q, edge, name="mul1")
s = math_ops.add(p, q, name="add1")
s = math_ops.sub(s, r, name="sub1")
- array_ops.squeeze(s, name=self.output_name)
+ array_ops.squeeze(s, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
- # "add", "sub1"];
- # - my_trt_op_1 should have ["weights","conv", "div"]
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(100, 12, 12, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 12, 12, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
+ # "add", "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ return ["my_trt_op_0", "my_trt_op_1"]
class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
@@ -143,6 +149,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing two segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -161,18 +168,20 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
c = constant_op.constant(1.0, name="c3")
n = math_ops.add(n, c, name="add3")
n = math_ops.mul(n, n, name="mul3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- # Only the first engine is built.
- "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ }
class PartiallyConvertedTestB(PartiallyConvertedTestA):
@@ -184,13 +193,12 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
trt_convert.clear_test_values("")
trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
- def GetParams(self):
- """Create a graph containing two segment."""
- return super(PartiallyConvertedTestB, self).GetParams()._replace(
- expected_engines={
- # Only the second engine is built.
- "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
- })
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ }
class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
@@ -199,6 +207,7 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -221,18 +230,20 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add", "add1", "mul"],
- "my_trt_op_1": ["add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
+ }
class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
@@ -241,6 +252,7 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing single segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -251,15 +263,17 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add")
n = math_ops.mul(n, n, name="mul")
n = math_ops.add(n, n, name="add1")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {"my_trt_op_0": ["c", "add", "add1", "mul"]}
class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -268,6 +282,7 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -282,22 +297,24 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add2", "add3", "mul1"],
- # Why segment ["add", "add1", "mul"] was assigned segment id 1
- # instead of 0: the parent node of this segment is actually const
- # node 'c', but it's removed later since it's const output of the
- # segment which is not allowed.
- "my_trt_op_1": ["add", "add1", "mul"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+ # instead of 0: the parent node of this segment is actually const
+ # node 'c', but it's removed later since it's const output of the
+ # segment which is not allowed.
+ "my_trt_op_1": ["add", "add1", "mul"]
+ }
class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
@@ -306,6 +323,7 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -328,18 +346,20 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
mul1 = math_ops.mul(add2, add2, name="mul1")
with g.control_dependencies([d1, d2, add, add1]):
add3 = math_ops.add(mul1, mul1, name="add3")
- array_ops.squeeze(add3, name=self.output_name)
+ array_ops.squeeze(add3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["c1", "add", "add1", "mul"],
- "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ }
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 2e1107e303..2f153c6f2f 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -37,6 +37,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 12]
+ output_name = "output"
w1_name = "matmul_w1"
w1_dims = [12, 5, 12, 7]
w2_name = "matmul_w2"
@@ -61,15 +62,46 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
x3 = x3 + f
x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7])
out = x1 + x2 + x3
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(12, 5, 8, 7),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 7)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return ["my_trt_op_0", "my_trt_op_1"]
+ return ["my_trt_op_1"]
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return ["my_trt_op_1"]
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): Trt library will fail like:
+ #
+ # ../builder/cudnnBuilder2.cpp:685:
+ # virtual std::vector<nvinfer1::query::Ports<
+ # nvinfer1::query::TensorRequirements>>
+ # nvinfer1::builder::Node::getSupportedFormats(
+ # const nvinfer1::query::Ports<nvinfer1::query::AbstractTensor>&,
+ # const nvinfer1::cudnn::HardwareContext&,
+ # nvinfer1::builder::Format::Type,
+ # const nvinfer1::builder::FormatTypeHack&) const:
+ # Assertion `sf' failed.
+ #
+ # To reproduce, run:
+ # bazel test -c opt --copt=-mavx \
+ # --test_arg=BatchMatMulTest.testTfTrt_ToolConversion_INT8_DynamicEngine \
+ # tensorflow/contrib/tensorrt:batch_matmul_test
+ #
+ # Investigate and fix it.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 8be32f59b4..62f4e525f7 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -38,6 +38,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [48, 12]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -97,18 +98,59 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
out = array_ops.concat(
[x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1)
- out = array_ops.squeeze(out, name=self.output_name)
+ out = array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
- ],
- expected_output_dims=(48, 89),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(48, 89)])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return super(BiasaddMatMulTest,
+ self).GetConversionParams(run_params)._replace(
+ max_batch_size=48, maximum_cached_engines=2)
+
+ def _ValidEngines(self):
+ """Engines expected to build and run."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_6",
+ "my_trt_op_7", "my_trt_op_8", "my_trt_op_9"
+ ]
+
+ def _InvalidEngines(self):
+ """Engines that will cause conversion error at building time."""
+ return ["my_trt_op_3", "my_trt_op_4", "my_trt_op_5"]
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # In dynamic engine mode the engines are built in execution time, not in
+ # conversion time, so build errors occurs later. Here three of the engines
+ # will be failed to built but the corresponding engine op are still created.
+ # TODO(aaroey, jjsjann123): fix this.
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return self._ValidEngines() + self._InvalidEngines()
+ return self._ValidEngines()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self._ValidEngines()
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8
+ # mode, which is a bug. Re-enable this when trt library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index 9316b14da0..f126ed4238 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -37,6 +37,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [10, 24, 24, 20]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -104,32 +105,34 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
a = constant_op.constant(np.random.randn(24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
- gen_array_ops.reshape(x, [5, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [5, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0",
- "my_trt_op_1",
- "my_trt_op_2",
- "my_trt_op_3",
- "my_trt_op_4",
- "my_trt_op_5",
- "my_trt_op_6",
- "my_trt_op_7",
- "my_trt_op_8",
- "my_trt_op_9",
- "my_trt_op_10",
- "my_trt_op_11",
- "my_trt_op_12",
- "my_trt_op_13",
- "my_trt_op_14",
- "my_trt_op_15",
- ],
- expected_output_dims=(5, 23040),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 23040)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 1874b9dd45..465cb02296 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -37,6 +37,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 3, 1]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -68,15 +69,17 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
concat1 = array_ops.concat([r1, r2, r3, r4, r5, r6], axis=-1)
concat2 = array_ops.concat([r7, r8, r9, r10, r11, r12], axis=3)
x = array_ops.concat([concat1, concat2], axis=-1)
- gen_array_ops.reshape(x, [2, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [2, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 126),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 126)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 8c59000b70..e32f047866 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -36,6 +36,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = 'input'
input_dims = [5, 12, 12, 2]
+ output_name = 'output'
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -53,15 +54,25 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype=dtype,
name='filt3')
y3 = nn.conv2d(z2, filt3, strides=[1, 1, 1, 1], padding='SAME', name='y3')
- nn.relu(y3, name='output')
+ nn.relu(y3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=['my_trt_op_0'],
- expected_output_dims=(5, 12, 12, 1),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(5, 12, 12, 1)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ['my_trt_op_0']
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
if __name__ == '__main__':
diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py
new file mode 100644
index 0000000000..1187c759b4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/manual_test.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Basic tests for TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import os
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+class ManualTest(trt_test.TfTrtIntegrationTestBase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(ManualTest, self).__init__(methodName)
+ self._params_map = None
+
+ def _GetEnv(self):
+ """Get an environment variable specifying the manual test parameters.
+
+ The value of the environment variable is the string representation of a dict
+ which should contain the following keys:
+ - 'graph_path': the file path to the serialized frozen graphdef
+ - 'input_names': TfTrtIntegrationTestParams.input_names
+ - 'input_dims': TfTrtIntegrationTestParams.input_dims
+ - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims
+ - 'output_name': the name of op to fetch
+ - 'expected_engines_to_run': ExpectedEnginesToRun() will return this
+ - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this
+ - 'max_batch_size': ConversionParams.max_batch_size
+
+ Returns:
+ The value of the environment variable.
+ """
+ return os.getenv('TRT_MANUAL_TEST_PARAMS', '')
+
+ def _GetParamsMap(self):
+ """Parse the environment variable as a dict and return it."""
+ if self._params_map is None:
+ self._params_map = ast.literal_eval(self._GetEnv())
+ return self._params_map
+
+ def GetParams(self):
+ """Testing conversion of manually provided frozen graph."""
+ params_map = self._GetParamsMap()
+ gdef = graph_pb2.GraphDef()
+ with gfile.Open(params_map['graph_path'], 'rb') as f:
+ gdef.ParseFromString(f.read())
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=gdef,
+ input_names=params_map['input_names'],
+ input_dims=params_map['input_dims'],
+ output_names=params_map['output_names'],
+ expected_output_dims=params_map['expected_output_dims'])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ conversion_params = super(ManualTest, self).GetConversionParams(run_params)
+ params_map = self._GetParamsMap()
+ if 'max_batch_size' in params_map:
+ conversion_params = conversion_params._replace(
+ max_batch_size=params_map['max_batch_size'])
+ return conversion_params
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return self._GetParamsMap()['expected_engines_to_build']
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ params_map = self._GetParamsMap()
+ if 'expected_engines_to_run' in params_map:
+ return params_map['expected_engines_to_run']
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'atol' in params_map:
+ return params_map['atol']
+ return 1.e-3
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'rtol' in params_map:
+ return params_map['rtol']
+ return 1.e-3
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return len(self._GetEnv())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 66eb6be757..bc7c90081f 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -36,6 +36,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 15, 15, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -57,15 +58,25 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
strides=[1, 1, 1, 1],
padding="VALID",
name="conv_2")
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 15, 15, 10),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(2, 15, 15, 10)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 0.1
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index fd55b8cd99..11be4feaf7 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -38,6 +38,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -72,15 +73,17 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
t = t + q
t = t + d
t = t - edge3
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0", "my_trt_op_1"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 51c905a50b..eddeafa38b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -37,6 +37,7 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -54,18 +55,20 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
t = math_ops.mul(conv, b, name="mul")
e = self.trt_incompatible_op(conv, name="incompatible")
t = math_ops.sub(t, e, name="sub")
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["bias", "mul", "sub"],
- "my_trt_op_1": ["weights", "conv"]
- },
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ }
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py
new file mode 100644
index 0000000000..74a4a05925
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class RankTwoTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Test for rank 2 input in TF-TRT."""
+ input_names = ["input", "input2"]
+ # Two paths: first with rank 2 input, second with rank 4 input.
+ input_dims = [[12, 5], [12, 5, 2, 2]]
+ output_name = "output"
+ g = ops.Graph()
+ with g.as_default():
+ outputs = []
+ for i in range(2):
+ x = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims[i], name=input_names[i])
+ c = constant_op.constant(1.0, name="c%d_1" % i)
+ q = math_ops.add(x, c, name="add%d_1" % i)
+ q = math_ops.abs(q, name="abs%d_1" % i)
+ c = constant_op.constant(2.2, name="c%d_2" % i)
+ q = math_ops.add(q, c, name="add%d_2" % i)
+ q = math_ops.abs(q, name="abs%d_2" % i)
+ c = constant_op.constant(3.0, name="c%d_3" % i)
+ q = math_ops.add(q, c, name="add%d_3" % i)
+ if i == 0:
+ for j in range(2):
+ q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j))
+ q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i)
+ outputs.append(q)
+ # Combine both paths
+ q = math_ops.add(outputs[0], outputs[1], name="add")
+ array_ops.squeeze(q, name=output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=input_names,
+ input_dims=input_dims,
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims[1])])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": [
+ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1",
+ "abs0_2"
+ ],
+ "my_trt_op_1": [
+ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3",
+ "abs1_1", "abs1_2", "reciprocal0", "reciprocal1"
+ ],
+ }
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8
+ # mode, which is a bug. Re-enable this when trt library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 6f85ada464..65ca21cf37 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
@@ -39,18 +40,23 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "expected_engines",
- "expected_output_dims", "allclose_atol", "allclose_rtol"
+ "gdef", "input_names", "input_dims", "output_names", "expected_output_dims"
])
RunParams = namedtuple(
"RunParams",
["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+ConversionParams = namedtuple("ConversionParams", [
+ "max_batch_size", "max_workspace_size_bytes", "precision_mode",
+ "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
+ "cached_engine_batches"
+])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
-def _IsQuantizationMode(mode):
+def IsQuantizationMode(mode):
return mode == "INT8"
@@ -64,10 +70,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@property
- def output_name(self):
- return "output"
-
- @property
def trt_incompatible_op(self):
return math_ops.sin
@@ -112,6 +114,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
super(TfTrtIntegrationTestBase, cls).setUpClass()
trt_convert.enable_test_value()
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(TfTrtIntegrationTestBase, self).__init__(methodName)
+ self._trt_test_params = None
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
@@ -122,43 +128,97 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _PrepareRun(self, params, graph_state):
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return ConversionParams(
+ max_batch_size=max([
+ dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
+ ]),
+ max_workspace_size_bytes=1 << 25,
+ precision_mode=self._ToBytes(run_params.precision_mode),
+ minimum_segment_size=2,
+ is_dynamic_op=run_params.dynamic_engine,
+ maximum_cached_engines=1,
+ cached_engine_batches=None)
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return True
+
+ def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True):
+ """Verify the state of a particular engine after sess.run()."""
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ if expect_run:
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+ else:
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+
+ def VerifyRun(self, run_params, graph_state):
+ """Verify the state of all engines after sess.run()."""
+ for engine_name in self.ExpectedEnginesToBuild(run_params):
+ expect_run = (engine_name in self.ExpectedEnginesToRun(run_params))
+ self.VerifyRunForEngine(engine_name, graph_state, expect_run)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build, implemented by subclass."""
+ raise NotImplementedError()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def _GetParamsCached(self):
+ if self._trt_test_params is None:
+ self._trt_test_params = self.GetParams()
+ return self._trt_test_params
+
+ def _PrepareRun(self, graph_state):
"""Set up necessary testing environment before calling sess.run()."""
# Clear test values added by TRTEngineOp.
trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
- def _VerifyRun(self, params, graph_state):
- """Verify the state after sess.run()."""
- for engine_name in params.expected_engines:
- if graph_state == GraphState.ORIGINAL:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.CALIBRATE:
- self._ExpectCalibration(engine_name, "done")
- self._ExpectNativeSegment(engine_name, "done")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.INFERENCE:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "done")
-
- def _GetConfigProto(self, params, run_params, graph_state):
+ def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 2
- custom_op.parameter_map["max_batch_size"].i = max(
- [dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- run_params.precision_mode)
+ trt_params = self.GetConversionParams(run_params)
+ custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
+ custom_op.parameter_map["max_workspace_size_bytes"].i = (
+ trt_params.max_workspace_size_bytes)
+ custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
+ custom_op.parameter_map["minimum_segment_size"].i = (
+ trt_params.minimum_segment_size)
+ custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
+ custom_op.parameter_map["maximum_cached_engines"].i = (
+ trt_params.maximum_cached_engines)
+ if trt_params.cached_engine_batches:
+ custom_op.parameter_map["cached_engine_batches"].list.i.extend(
+ trt_params.cached_engine_batches)
+
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
@@ -190,53 +250,67 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _ExpectNativeSegment(self, engine_name, value):
self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
- def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ def _RunGraph(self,
+ run_params,
+ gdef,
+ input_data,
+ config,
+ graph_state,
num_runs=2):
"""Run given graphdef multiple times."""
+ params = self._GetParamsCached()
assert len(params.input_names) == len(input_data)
g = ops.Graph()
with g.as_default():
io_ops = importer.import_graph_def(
graph_def=gdef,
- return_elements=params.input_names + [self.output_name],
+ return_elements=params.input_names + params.output_names,
name="")
- inp = [i.outputs[0] for i in io_ops[:-1]]
- assert len(inp) == len(input_data)
- out = io_ops[-1].outputs[0]
+ inputs = [op.outputs[0] for op in io_ops[:len(params.input_names)]]
+ assert len(inputs) == len(input_data)
+ outputs = [op.outputs[0] for op in io_ops[len(params.input_names):]]
with self.test_session(
graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
val = None
# Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs):
- self._PrepareRun(params, graph_state)
- new_val = sess.run(out,
- {inp[i]: input_data[i] for i in range(len(inp))})
- self.assertEqual(params.expected_output_dims, new_val.shape)
+ self._PrepareRun(graph_state)
+ new_val = sess.run(
+ outputs, {inputs[i]: input_data[i] for i in range(len(inputs))})
+ output_len = len(params.expected_output_dims)
+ self.assertEqual(output_len, len(new_val))
+ for i in range(output_len):
+ self.assertEqual(params.expected_output_dims[i], new_val[i].shape)
if val is not None:
- self.assertAllEqual(val, new_val)
+ self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06)
val = new_val
- self._VerifyRun(params, graph_state)
+ self.VerifyRun(run_params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
- def _RunCalibration(self, params, gdef, input_data, config):
+ def _RunCalibration(self, run_params, gdef, input_data, config):
"""Run calibration on given graph."""
return self._RunGraph(
- params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
+ run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, run_params, gdef):
+ def _GetTrtGraphDef(self, run_params, gdef):
"""Return trt converted graphdef."""
+ params = self._GetParamsCached()
+ trt_params = self.GetConversionParams(run_params)
+ logging.info(trt_params)
return trt_convert.create_inference_graph(
input_graph_def=gdef,
- outputs=[self.output_name],
- max_batch_size=max([dims[0] for dims in params.input_dims]),
- max_workspace_size_bytes=1 << 25,
- precision_mode=run_params.precision_mode,
- minimum_segment_size=2,
- is_dynamic_op=run_params.dynamic_engine)
-
- def _WriteGraph(self, params, run_params, gdef, graph_state):
+ outputs=params.input_names + params.output_names,
+ max_batch_size=trt_params.max_batch_size,
+ max_workspace_size_bytes=trt_params.max_workspace_size_bytes,
+ precision_mode=trt_params.precision_mode,
+ minimum_segment_size=trt_params.minimum_segment_size,
+ is_dynamic_op=trt_params.is_dynamic_op,
+ maximum_cached_engines=trt_params.maximum_cached_engines,
+ cached_engine_batches=trt_params.cached_engine_batches)
+
+ def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL:
label = "Original"
elif graph_state == GraphState.CALIBRATE:
@@ -247,15 +321,17 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
".pbtxt")
temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
- logging.info("Writing graph to %s/%s", temp_dir, graph_name)
- graph_io.write_graph(gdef, temp_dir, graph_name)
+ if temp_dir:
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
- def _VerifyConnections(self, params, converted_gdef):
+ def _VerifyConnections(self, expected_engines, converted_gdef):
+ params = self._GetParamsCached()
old_to_new_node_map = {
self._ToString(node.name): self._ToString(node.name)
for node in params.gdef.node
}
- for engine_name, node_names in params.expected_engines.items():
+ for engine_name, node_names in expected_engines.items():
for node_name in node_names:
old_to_new_node_map[node_name] = engine_name
name_to_node_map = {
@@ -310,97 +386,114 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
msg="expected:\n%s\nvs actual:\n%s" % (sorted(
expected_input_map.items()), sorted(actual_input_map.items())))
- def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
- self._WriteGraph(params, run_params, gdef, graph_state)
+ def _VerifyGraphDef(self, run_params, gdef, graph_state):
+ self._WriteGraph(run_params, gdef, graph_state)
+ expected_engines = self.ExpectedEnginesToBuild(run_params)
num_engines = 0
for node in gdef.node:
if node.op == "TRTEngineOp":
+ logging.info("Found TRTEngineOp: " + node.name)
+ for node in gdef.node:
+ if node.op == "TRTEngineOp":
num_engines += 1
- self.assertTrue(node.name in params.expected_engines)
- self.assertTrue(len(node.attr["serialized_segment"].s))
- self.assertTrue(len(node.attr["segment_funcdef_name"].s))
+ self.assertTrue(node.name in expected_engines, node.name)
+ self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
+ self.assertTrue(len(node.attr["segment_funcdef_name"].s), node.name)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
- node.attr["precision_mode"].s)
+ node.attr["precision_mode"].s, node.name)
is_dynamic_engine = not node.attr["static_engine"].b
- self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
+ node.name)
has_calibration_data = len(node.attr["calibration_data"].s)
- if (_IsQuantizationMode(run_params.precision_mode) and
+ if (IsQuantizationMode(run_params.precision_mode) and
graph_state == GraphState.INFERENCE):
- self.assertTrue(has_calibration_data)
+ self.assertTrue(has_calibration_data, node.name)
else:
- self.assertFalse(has_calibration_data)
+ self.assertFalse(has_calibration_data, node.name)
if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, len(params.expected_engines))
- if isinstance(params.expected_engines, dict):
- self._VerifyConnections(params, gdef)
+ self.assertEqual(num_engines, len(expected_engines))
+ if isinstance(expected_engines, dict):
+ self._VerifyConnections(expected_engines, gdef)
# TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, run_params):
+ def RunTest(self, run_params):
+ if not self.ShouldRunTest(run_params):
+ return
assert run_params.precision_mode in PRECISION_MODES
- input_data = [np.random.random_sample(dims) for dims in params.input_dims]
+
+ params = self._GetParamsCached()
input_gdef = params.gdef
- self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
+ input_dtypes = {}
+ for node in input_gdef.node:
+ if self._ToString(node.name) in params.input_names:
+ assert self._ToString(node.op) == "Placeholder"
+ input_dtypes[self._ToString(node.name)] = (
+ dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
+ assert len(params.input_names) == len(input_dtypes)
+
+ input_data = []
+ for i in range(len(params.input_names)):
+ dtype = input_dtypes[params.input_names[i]]
+ # Multiply the input by some constant to avoid all zeros input for integer
+ # types.
+ scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
+ dims = params.input_dims[i]
+ input_data.append((scale * np.random.random_sample(dims)).astype(dtype))
+ self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, run_params,
- GraphState.ORIGINAL)
+ config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
- GraphState.ORIGINAL)
+ ref_result = self._RunGraph(run_params, input_gdef, input_data,
+ config_no_trt, GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(run_params.precision_mode):
+ if IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, run_params,
- GraphState.CALIBRATE)
+ calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
if run_params.use_optimizer:
- result = self._RunCalibration(params, input_gdef, input_data,
+ result = self._RunCalibration(run_params, input_gdef, input_data,
calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
- self._VerifyGraphDef(params, run_params, calib_gdef,
- GraphState.CALIBRATE)
- result = self._RunCalibration(params, calib_gdef, input_data,
+ calib_gdef = self._GetTrtGraphDef(run_params, input_gdef)
+ self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE)
+ result = self._RunCalibration(run_params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(
+ calib_gdef, run_params.dynamic_engine)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
else:
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, run_params,
- GraphState.INFERENCE)
+ infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if run_params.use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
- else:
- trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
- self._VerifyGraphDef(params, run_params, trt_infer_gdef,
- GraphState.INFERENCE)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
+ if not run_params.use_optimizer:
+ infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
+ result = self._RunGraph(run_params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
def testIdempotence(self):
# Test that applying tensorrt optimizer or offline conversion tools multiple
@@ -421,13 +514,12 @@ def _AddTests(test_class):
"""Gets a single test method based on the parameters."""
def _Test(self):
- params = self.GetParams()
logging.info(
"Running test %s with parameters: use_optimizer=%s, "
"precision_mode=%s, dynamic_engine=%s",
"testTfTrt_" + run_params.test_name, run_params.use_optimizer,
run_params.precision_mode, run_params.dynamic_engine)
- self.RunTest(params, run_params)
+ self.RunTest(run_params)
return _Test
@@ -435,7 +527,7 @@ def _AddTests(test_class):
dynamic_engine_options = [False, True]
for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
- if _IsQuantizationMode(precision_mode):
+ if IsQuantizationMode(precision_mode):
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index 500057a36d..8736bfb644 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -38,6 +38,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 1, 1, 12]
+ output_name = "output"
input2_name = "input_2"
input2_dims = [12, 5, 8, 1, 12, 1, 1]
g = ops.Graph()
@@ -95,18 +96,20 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
q = a * b
q = q / c
- array_ops.squeeze(q, name=self.output_name)
+ array_ops.squeeze(q, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4"
- ],
- expected_output_dims=(12, 5, 8, 12),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 12)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index ab4d224db4..b0271a04b3 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -38,15 +38,14 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 2, 8, 8]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
data_format="NCHW",
is_training=False)
e = constant_op.constant(
@@ -67,15 +66,17 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
"VALID",
data_format="NCHW",
name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 6, 2, 2),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 6, 2, 2)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index 56bdf848ea..d7c165784b 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -38,15 +38,14 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 8, 8, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
is_training=False)
e = constant_op.constant(
np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
@@ -58,15 +57,17 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
idty = array_ops.identity(relu, "ID")
v = nn_ops.max_pool(
idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 2, 2, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 2, 2, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 355303acf6..71b0d48798 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -16,6 +16,7 @@ config_setting(
py_binary(
name = "predict",
srcs = ["predict.py"],
+ data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = select({
@@ -31,7 +32,6 @@ py_test(
name = "predict_test",
timeout = "long", # Moderate but for asan
srcs = ["predict_test.py"],
- data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = [
"no_windows", # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/timeseries/examples/predict.py b/tensorflow/contrib/timeseries/examples/predict.py
index 8147d40caa..b036911314 100644
--- a/tensorflow/contrib/timeseries/examples/predict.py
+++ b/tensorflow/contrib/timeseries/examples/predict.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import os
import sys
import numpy as np
@@ -40,6 +41,10 @@ except ImportError:
FLAGS = None
+_MODULE_PATH = os.path.dirname(__file__)
+_DEFAULT_DATA_FILE = os.path.join(_MODULE_PATH, "data/period_trend.csv")
+
+
def structural_ensemble_train_and_predict(csv_file_name):
# Cycle between 5 latent values over a period of 100. This leads to a very
# smooth periodic component (and a small model), which is a good fit for our
@@ -115,9 +120,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
+ input_filename = FLAGS.input_filename
+ if input_filename is None:
+ input_filename = _DEFAULT_DATA_FILE
make_plot("Structural ensemble",
- *structural_ensemble_train_and_predict(FLAGS.input_filename))
- make_plot("AR", *ar_train_and_predict(FLAGS.input_filename))
+ *structural_ensemble_train_and_predict(input_filename))
+ make_plot("AR", *ar_train_and_predict(input_filename))
pyplot.show()
@@ -126,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_filename",
type=str,
- required=True,
- help="Input csv file.")
+ required=False,
+ help="Input csv file (omit to use the data/period_trend.csv).")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
index 5eb4deefb9..de547f835d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
@@ -195,7 +195,7 @@ class ARModelTest(test.TestCase):
self.train_helper(input_window_size=10,
loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
train_steps=300,
- max_loss=2.5,
+ max_loss=50., # Just make sure there are no exceptions.
anomaly_distribution=None)
def test_autoregression_normal_multiple_periods(self):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 983455f63d..461fe22210 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -69,8 +69,10 @@ class TimeSeriesRegressorTest(test.TestCase):
input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
batch_size=16, window_size=16)
first_estimator.train(input_fn=train_input_fn, steps=1)
- first_loss_before_fit = first_estimator.evaluate(
- input_fn=eval_input_fn, steps=1)["loss"]
+ first_evaluation = first_estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ first_loss_before_fit = first_evaluation["loss"]
+ self.assertAllEqual(first_loss_before_fit, first_evaluation["average_loss"])
self.assertAllEqual([], first_loss_before_fit.shape)
first_estimator.train(input_fn=train_input_fn, steps=1)
first_loss_after_fit = first_estimator.evaluate(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 32194e400e..1f9f9b7aa6 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.summary import summary
@@ -123,6 +124,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
_identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
model_outputs.end_state))
+ metrics[metric_keys.MetricKeys.LOSS_MEAN] = metrics_impl.mean(
+ model_outputs.loss, name="average_loss")
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index bda3b53aca..e65e7b74d4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -172,6 +172,7 @@ class EvaluationMetricsTests(test.TestCase):
evaluation = estimator.evaluate(input_fn, steps=1)
self.assertIn("plain_boring_metric386", evaluation)
self.assertIn("fun_metric101", evaluation)
+ self.assertIn("average_loss", evaluation)
# The values are deterministic because of fixed tf_random_seed.
# However if they become flaky, remove such exacts comparisons.
self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380)
@@ -398,6 +399,7 @@ class OneShotTests(parameterized.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
result = estimator.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertIn("average_loss", result)
self.assertNotIn(feature_keys.State.STATE_TUPLE, result)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
export_location = estimator.export_savedmodel(_new_temp_dir(),
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index 1f249de314..feb177a7da 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -8,6 +8,8 @@ message Profile {
Node by_category = 1;
// Root of a profile broken down by program structure.
Node by_program_structure = 2;
+ // Per program profile, indexed by hlo module name of the program.
+ map<string, Node> per_program = 3;
}
// An entry in the profile tree. (An instruction, or set of instructions).
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index ff893a722f..a5e8277ba5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -54,7 +54,7 @@ import time
import numpy as np
-from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
+from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -80,12 +80,54 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
+
+
+_SESSIONS = {}
+
+
+def tpu_session(cluster_resolver):
+ """Construct or return a `tf.Session` connected to the given cluster."""
+ global _SESSIONS
+ master = cluster_resolver.master()
+ if master not in _SESSIONS:
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
+ graph = ops.Graph()
+ session = tf_session.Session(graph=graph, target=master, config=config)
+
+ with graph.as_default():
+ session.run(tpu.initialize_system())
+
+ _SESSIONS[master] = session
+ return _SESSIONS[master]
+
+
+def reset_tpu_sessions():
+ _SESSIONS.clear()
# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
+def TPUDistributionStrategy(tpu_cluster_resolver=None): # pylint: disable=invalid-name
+ """Construct a TPUDistributionStrategy."""
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- return tpu_strategy.TPUStrategy(*args, **kw)
+ # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
+ if tpu_cluster_resolver is None:
+ tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+ args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
+ if len(args) == 3:
+ logging.info('Detected new TPUStrategy API.')
+ return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1)
+ else:
+ logging.info('Detected old TPUStrategy API.')
+ strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
+ strategy._tpu_cluster_resolver = tpu_cluster_resolver
+
+ return strategy
class TPUEmbedding(embeddings.Embedding):
@@ -666,9 +708,10 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
- # TODO(power): Replicate variables.
- with ops.device('/device:TPU:0'):
- self._cloned_model = models.clone_model(self.model)
+ with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ # TODO(power): Replicate variables.
+ with ops.device('/device:TPU:0'):
+ self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
@@ -845,7 +888,7 @@ class TPUFunction(object):
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
- def __init__(self, cpu_model, tpu_name_or_address, strategy):
+ def __init__(self, cpu_model, strategy):
super(models.Model, self).__init__( # pylint: disable=bad-super-call
inputs=cpu_model.inputs,
outputs=cpu_model.outputs,
@@ -862,27 +905,14 @@ class KerasTPUModel(models.Model):
self.train_function = None
self._strategy = strategy
- self._tpu_name_or_address = tpu_name_or_address
+ cluster_resolver = self._strategy._tpu_cluster_resolver
+ self._tpu_name_or_address = cluster_resolver.get_master()
self._cpu_model = cpu_model
self._tpu_model = None
self._tpu_weights_initialized = False
- self._graph = ops.Graph()
-
- self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
- tpu_name_or_address)
- master = self._cluster_resolver.master()
- cluster_spec = self._cluster_resolver.cluster_spec()
- self._session = tf_session.Session(
- graph=self._graph,
- target=master,
- config=config_pb2.ConfigProto(isolate_session_state=True))
-
- # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
- if cluster_spec:
- self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- with self._graph.as_default():
- self._session.run(tpu.initialize_system())
+ self._session = tpu_session(cluster_resolver)
+ self._graph = self._session.graph
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
@@ -1137,7 +1167,7 @@ Output shape: %(output_shape)s
@experimental
-def tpu_model(model, tpu_name_or_address=None, strategy=None):
+def tpu_model(model, strategy=None):
"""Copy `model` along with weights to the TPU. Returns a TPU model.
Usage:
@@ -1148,7 +1178,7 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
# If `num_cores_per_host` is greater than one, batch parallelism will be used
# to run on multiple TPU cores.
- strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+ strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
model = keras_support.tpu_model(model, strategy)
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
@@ -1158,10 +1188,6 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
Args:
model: A `KerasTPUModel`.
- tpu_name_or_address: A string that is either the name of the Cloud TPU,
- the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the
- Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will
- examine the environment to determine a potential Cloud TPU to use.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
model across multiple TPU cores.
@@ -1176,9 +1202,8 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
# TODO(xiejw): Adds reduction option.
+
if strategy is None:
- strategy = TPUDistributionStrategy(num_cores_per_host=1)
- return KerasTPUModel(
- cpu_model=model,
- tpu_name_or_address=tpu_name_or_address,
- strategy=strategy)
+ strategy = TPUDistributionStrategy()
+
+ return KerasTPUModel(cpu_model=model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index f221155568..fed07f00e7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -762,9 +762,13 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
if not is_dataset:
raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
'input pipeline configuration.')
+
if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
- # TODO(b/XXX): Add predict support for PER_HOST_V2
- raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
+ inputs = _InputsWithStoppingSignals(
+ dataset=inputs.dataset,
+ batch_size=ctx.batch_size_for_input_fn,
+ add_padding=True,
+ num_invocations_per_step=ctx.num_of_replicas_per_host)
hooks.append(inputs.dataset_initializer_hook())
tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
@@ -774,6 +778,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
control_deps = []
per_host_sharded_inputs = []
num_replicas_per_host = ctx.num_of_replicas_per_host
+ cached_signals = None
with ops.device(device):
if not inputs.is_dataset:
raise TypeError('`input_fn` must return a `Dataset` for this mode.')
@@ -781,12 +786,20 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
# Use control dependencies to ensure a deterministic ordering.
with ops.control_dependencies(control_deps):
features, labels = inputs.features_and_labels() # Calls get_next()
+ signals = inputs.signals()
+
+ # All the replicas share the replica 0's stopping singal.
+ # This avoids inconsistent state among different model replcias.
+ if cached_signals:
+ signals['stopping'] = cached_signals['stopping']
+ else:
+ cached_signals = signals
inputs_structure_recorder.validate_and_record_structure(
features, labels)
flattened_inputs = (
inputs_structure_recorder.flatten_features_and_labels(
- features, labels))
+ features, labels, signals))
control_deps.extend(flattened_inputs)
per_host_sharded_inputs.append(flattened_inputs)
@@ -807,7 +820,13 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
tpu_ordinal_function=tpu_ordinal_function_impl)
captured_infeed_queue.capture(infeed_queue)
- return per_host_enqueue_ops
+ if signals is None:
+ return per_host_enqueue_ops
+ else:
+ return {
+ 'ops': per_host_enqueue_ops,
+ 'signals': signals,
+ }
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
@@ -3043,16 +3062,48 @@ class _Inputs(object):
class _InputsWithStoppingSignals(_Inputs):
"""Inputs with `_StopSignals` inserted into the dataset."""
- def __init__(self, dataset, batch_size, add_padding=False):
+ def __init__(self,
+ dataset,
+ batch_size,
+ add_padding=False,
+ num_invocations_per_step=1):
assert dataset is not None
-
user_provided_dataset = dataset.map(
_InputsWithStoppingSignals.insert_stopping_signal(
stop=False, batch_size=batch_size, add_padding=add_padding))
- final_batch_dataset = dataset.take(1).map(
- _InputsWithStoppingSignals.insert_stopping_signal(
- stop=True, batch_size=batch_size, add_padding=add_padding))
+ if num_invocations_per_step == 1:
+ final_batch_dataset = dataset.take(1).map(
+ _InputsWithStoppingSignals.insert_stopping_signal(
+ stop=True, batch_size=batch_size, add_padding=add_padding))
+ else:
+ # We append (2 * num_invocations_per_step - 1) batches for exhausting the
+ # user_provided_dataset and stop properly.
+ # For example, if num_invocations_per_step is 2, we append 3 additional
+ # padding batches: b1, b2, b3.
+ # If user_provided_dataset contains two batches: a1, a2
+ # Step 1: [a1, a2]
+ # Step 2: [b1, b2] -> STOP
+ # If user_provided_dataset contains three batches: a1, a2, a3.
+ # The training loops:
+ # Step 1: [a1, a2]
+ # Step 2: [a3, b1]
+ # Step 3: [b2, b3] -> STOP.
+ final_batch_dataset = dataset.take(1).map(
+ _InputsWithStoppingSignals.insert_stopping_signal(
+ stop=True, batch_size=batch_size, add_padding=add_padding))
+ final_batch_dataset = final_batch_dataset.repeat(
+ 2 * num_invocations_per_step - 1)
+
+ def _set_mask(data_dict):
+ signals = data_dict['signals']
+ signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
+ data_dict['signals'] = signals
+ return data_dict
+
+ # Mask out the extra batch.
+ final_batch_dataset = final_batch_dataset.map(_set_mask)
+
dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
index 3e90957e6d..bd530fdc3a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
@@ -286,6 +286,59 @@ class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(sliced_features)
+ def test_slice_with_multi_invocations_per_step(self):
+ num_samples = 3
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(
+ dataset, batch_size, add_padding=True, num_invocations_per_step=2)
+ hook = inputs.dataset_initializer_hook()
+ features, _ = inputs.features_and_labels()
+ signals = inputs.signals()
+
+ sliced_features = (
+ tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # This is the final partial batch.
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertEqual(1, len(result['a']))
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # We should see 3 continuous batches with STOP ('1') as signals and all
+ # of them have mask 1.
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(sliced_features)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index edd71fb250..3547e71184 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Training and input utilities.
-See @{$python/contrib.training} guide.
+See
+[Contrib Training](https://tensorflow.org/api_guides/python/contrib.training)
+guide.
@@batch_sequences_with_states
@@NextQueuedSequenceBatch
diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py
index 08741cf8ca..338acef63f 100644
--- a/tensorflow/contrib/util/__init__.py
+++ b/tensorflow/contrib/util/__init__.py
@@ -15,7 +15,7 @@
"""Utilities for dealing with Tensors.
-See @{$python/contrib.util} guide.
+See [Contrib Util](https://tensorflow.org/api_guides/python/contrib.util) guide.
@@constant_value
@@make_tensor_proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9a8c20b1fd..515237ff29 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -121,6 +121,7 @@ load(
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
"tf_additional_proto_hdrs",
+ "tf_additional_proto_compiler_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
@@ -128,6 +129,7 @@ load(
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
"tf_lib_proto_parsing_deps",
+ "tf_lib_proto_compiler_deps",
"tf_nano_proto_library",
"tf_platform_hdrs",
"tf_platform_srcs",
@@ -613,6 +615,17 @@ cc_library(
],
)
+cc_library(
+ name = "lib_proto_compiler",
+ hdrs = [
+ "platform/protobuf_compiler.h",
+ ] + tf_additional_proto_compiler_hdrs(),
+ copts = tf_copts(),
+ deps = tf_lib_proto_compiler_deps() + [
+ ":lib_proto_parsing",
+ ],
+)
+
# This build rule (along with :lib_internal, :framework, and
# :framework_internal) purposefully omits the definitions of many declared
# symbols, which are included in //tensorflow:libtensorflow_framework.so. Using
@@ -2337,12 +2350,12 @@ tf_generate_proto_text_sources(
srcs = COMMON_PROTO_SRCS,
protodeps = ERROR_CODES_PROTO_SRCS,
srcs_relative_dir = "tensorflow/core/",
+ visibility = ["//visibility:public"],
deps = [
":error_codes_proto_text",
":lib_internal",
":protos_all_proto_cc",
],
- visibility = ["//visibility:public"],
)
cc_library(
@@ -2450,10 +2463,10 @@ cc_header_only_library(
cc_header_only_library(
name = "core_cpu_headers_lib",
+ visibility = ["//visibility:public"],
deps = [
":core_cpu_lib",
],
- visibility = ["//visibility:public"],
)
tf_cuda_library(
@@ -2574,8 +2587,8 @@ tf_cuda_library(
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
- deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
visibility = ["//visibility:public"],
+ deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
)
# Library containing all of the graph construction code that is
diff --git a/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..5604a1a89e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,9 @@
+op {
+ graph_op_name: "DivNoNan"
+ summary: "Returns 0 if the denominator is zero."
+ description: <<END
+
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
index 58262a385c..37d1a9dcbf 100644
--- a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
@@ -27,5 +27,15 @@ For example:
fill([2, 3], 9) ==> [[9, 9, 9]
[9, 9, 9]]
```
+
+`tf.fill` differs from `tf.constant` in a few ways:
+
+* `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+ Tensor values.
+* `tf.fill` creates an Op in the computation graph that constructs the actual
+ Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+ the entire Tensor into the graph with a `Const` node.
+* Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+ based on other runtime Tensors, unlike `tf.constant`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
index 1a75e67c0c..e400c7402b 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
@@ -70,5 +70,7 @@ The resulting update to ref would look like this:
See `tf.scatter_nd` for more details about how to make updates to
slices.
+
+See also `tf.scatter_update` and `tf.batch_scatter_update`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
index 4804908afc..4037dee432 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
@@ -59,5 +59,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
</div>
+
+See also `tf.batch_scatter_update` and `tf.scatter_nd_update`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 5e2912fcdd..35f55fe106 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \max_j(data_j)\\) where `max` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index a7d85b3f4e..70a07d9b4c 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the mean along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index 74fc598218..b2e3eece38 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \min_j(data_j)\\) where `min` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 4c4363e524..7bac02e23d 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \prod_j data_j\\) where the product is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index 583ab3904f..a73306a892 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \sum_j data_j\\) where sum is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
index 866e04e97b..138a6366c8 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the mean along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
index af4bc75fa0..b8073d88ac 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
@@ -30,7 +30,8 @@ END
Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
misisng, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
index 194bcea726..945bbdcf62 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
@@ -23,7 +23,8 @@ END
description: <<END
N is the size of the segment being reduced.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
index 8b502928a5..ff328c8a61 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -32,7 +32,8 @@ N is the size of the segment being reduced.
Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
misisng, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
index dfd50bf273..a68e14607f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the sum along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
index 3bc16577ff..aa5c1fc8d0 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
@@ -30,8 +30,9 @@ END
Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
misisng, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
For example:
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
new file mode 100644
index 0000000000..e382bcec81
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "StaticRegexReplace"
+ in_arg {
+ name: "input"
+ description: "The text to be processed."
+ }
+ out_arg {
+ name: "output"
+ description: "The text after applying pattern and rewrite."
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ attr {
+ name: "rewrite"
+ description: "The rewrite to be applied to the matched expresion."
+ }
+ attr {
+ name: "replace_global"
+ description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match."
+ }
+ summary: "Replaces the match of pattern in input with rewrite."
+ description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
new file mode 100644
index 0000000000..cc21ddc815
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "StringLength"
+ in_arg {
+ name: "input"
+ description: <<END
+The string for which to compute the length.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Integer tensor that has the same shape as `input`. The output contains the
+element-wise string lengths of `input`.
+END
+ }
+ summary: "String lengths of `input`."
+ description: <<END
+Computes the length of each string given in the input tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 82c913d15e..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,5 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- summary: "Returns 0 if the denominator is zero."
- description: ""
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 4ca6780c95..907c6d2022 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 55ea69b5dd..37dd973b23 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index 577ff53d60..efbc023705 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index 9aeabd030d..a8874950eb 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -16,8 +16,9 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
diff --git a/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..1bf3fba3c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "DivNoNan"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
new file mode 100644
index 0000000000..01c02e1f70
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "StringLength"
+ endpoint {
+ name: "strings.length"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 56caabcf3c..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index 92307d78f2..cf1cd4134e 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -103,7 +103,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
return *this; \
}
-DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
@@ -119,9 +118,6 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
bool include_those_in_node_def) const {
- for (const auto& p : string_attrs_) {
- SetInAttrValueMap(m, p.first, p.second);
- }
for (const auto& p : int_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
@@ -211,10 +207,6 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
// not been called.
if (node_def_finalized_) return f;
}
- for (const auto& p : string_attrs_) {
- CombineUnordered(
- CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
- }
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index 929b1b8296..fc50bed3c0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -131,7 +131,6 @@ class AttrBuilder {
}
}
- AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
@@ -143,8 +142,6 @@ class AttrBuilder {
}; // namespace tensorflow
template <>
-AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
-template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5bdd547c7f..b859b06fa0 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
}
}
}
+
+ DeviceSet ds;
+ for (Device* d : devices_) {
+ ds.AddDevice(d);
+ }
+ prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
bool EagerContext::Async() const {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9835b19511..3c95ac590d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -93,6 +93,9 @@ class EagerContext {
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
+ const std::vector<DeviceType>& prioritized_device_type_list() {
+ return prioritized_device_type_list_;
+ }
// Clears the kernel caches.
void ClearCaches();
@@ -210,6 +213,7 @@ class EagerContext {
// Devices owned by device_manager
std::vector<Device*> devices_;
+ std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_;
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 46065f399c..5b3a64ba98 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
}
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
- DeviceSet ds;
- for (Device* d : *ctx->devices()) {
- ds.AddDevice(d);
- }
DeviceTypeVector final_devices;
- auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
- ndef, &final_devices);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ ctx->prioritized_device_type_list(), ndef, &final_devices));
if (final_devices.empty()) {
- return errors::Internal("Could not find valid device for node ",
- ndef.DebugString());
+ return errors::Internal(
+ "Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
+ "\nAll kernels registered for op ", ndef.op(), " :\n",
+ KernelsRegisteredForOp(ndef.op()));
}
for (Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) {
@@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
}
}
return errors::Unknown("Could not find a device for node ",
- ndef.DebugString());
+ SummarizeNodeDef(ndef));
}
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 951bc4197e..63ed860b9f 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -72,141 +72,58 @@ bool IsInitializationOp(const Node* node) {
return node->op_def().allows_uninitialized_input();
}
-// Sets the timeline_label field of *node_stats, using data from *node.
-// Returns true iff the node is a transfer node.
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- bool is_transfer_node = false;
- if (!stats) {
- return is_transfer_node;
- }
- string memory;
- for (auto& all : stats->stats()->memory()) {
- int64 tot = all.total_bytes();
- if (tot >= 0.1 * 1048576.0) {
- int64 peak = all.peak_bytes();
- if (peak > 0) {
- memory =
- strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
- peak / 1048576.0));
- } else {
- memory = strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB] ", tot / 1048576.0));
- }
- }
- }
- const AttrSlice attrs = node->attrs();
- string text;
- if (IsSend(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string recv_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string send_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", send_device);
- is_transfer_node = true;
- } else {
- text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
- }
- stats->stats()->set_timeline_label(text);
- return is_transfer_node;
-}
-
// Helper routines for collecting step stats.
namespace nodestats {
-inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 nanos) {
+void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
if (!stats) return;
- stats->stats()->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_scheduled_nanos(nanos);
+ stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
void SetAllStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
- int64 now_nanos = NowInNsec();
- stats->stats()->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_all_start_nanos(now_nanos);
+ stats->RecordExecutorStarted();
}
void SetOpStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_start_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeStarted();
}
void SetOpEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeEnded();
}
void SetAllEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_all_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordExecutorEnded();
}
void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
if (!stats) return;
- DCHECK(v);
- NodeOutput* no = stats->stats()->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
+ stats->SetOutput(slot, v);
}
void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
if (!stats) return;
-
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- stats->AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats->stats()->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+ stats->SetMemory(ctx);
}
void SetReferencedTensors(NodeExecStatsWrapper* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description =
- stats->stats()->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
+ stats->SetReferencedTensors(tensors);
+}
+
+// Sets the timeline_label field of *stats, using data from *node.
+// Returns true iff the node is a transfer node.
+bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
+ if (!stats) {
+ return false;
}
+ return stats->SetTimelineLabel(node);
}
} // namespace nodestats
@@ -1694,8 +1611,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper;
- stats->stats()->set_node_name(node->name());
+ stats = new NodeExecStatsWrapper(node->name());
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -2165,7 +2081,8 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
NodeExecStatsWrapper* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
+ if (stats_collector_ != nullptr &&
+ !nodestats::SetTimelineLabel(node, stats)) {
// Only record non-transfer nodes.
// Transfers 'stats' ownership to 'stats_collector_'.
stats_collector_->Save(impl_->params_.device->name(), stats);
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index af6880c6b3..9c2510e6a9 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -16,12 +16,16 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -36,11 +40,89 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper()
- : NodeExecStatsWrapper(new NodeExecStats) {}
+NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
+ : NodeExecStatsWrapper(new NodeExecStats) {
+ stats_->set_node_name(node_name);
+}
NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
: stats_(stats) {}
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
+ DCHECK(v);
+ NodeOutput* no = stats_->add_output();
+ no->set_slot(slot);
+ v->FillDescription(no->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // be careful not to increment the reference count on any tensor
+ // while recording the information
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
+}
+
+// TODO(tucker): merge with the DetailText function in session.cc
+// in a common location.
+bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
+ bool is_transfer_node = false;
+ string memory;
+ for (auto& all : stats_->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const AttrSlice attrs = node->attrs();
+ string text;
+ if (IsSend(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
+ text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ "(", tensor_name, " @", recv_device);
+ is_transfer_node = true;
+ } else if (IsRecv(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
+ text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ "(", tensor_name, " @", send_device);
+ is_transfer_node = true;
+ } else {
+ text =
+ strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
+ str_util::Join(node->requested_inputs(), ", "), ")");
+ }
+ stats_->set_timeline_label(text);
+ return is_transfer_node;
+}
+
void NodeExecStatsWrapper::AddAllocation(
Allocator* allocator, TrackingAllocator* tracking_allocator) {
AllocatorMemoryUsed* memory = stats_->add_memory();
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 0394f25839..7206fbf427 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -19,7 +19,9 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -30,33 +32,99 @@ class Allocator;
class AllocatorMemoryUsed;
class CostModelManager;
class Graph;
+class Node;
class NodeExecStats;
+class OpKernelContext;
class StepStats;
+class Tensor;
class TrackingAllocator;
// Wraps NodeExecStats and adds allocation to it.
class NodeExecStatsWrapper {
public:
- NodeExecStatsWrapper();
+ NodeExecStatsWrapper(const string& node_name);
// Owns 'stats'.
NodeExecStatsWrapper(NodeExecStats* stats);
// Destructor calls Finalize() to release the TrackingAllocators.
~NodeExecStatsWrapper() { Finalize(); }
- NodeExecStats* stats() { return stats_.get(); }
-
- // "Does not take ownership of the 'allocator'.
- // Transfers ownership of the 'tracking_allocator' to *this."
- void AddAllocation(Allocator* allocator,
- TrackingAllocator* tracking_allocator);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ void SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+ }
+
+ // Called immediately after this node starts being processed by the executor.
+ void RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+ }
+
+ // Called immediately before this node's `Compute()` or `ComputeAsync()`
+ // method is called.
+ void RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Called immediately after this node's `Compute()` method returned (or, for
+ // asynchronous operations, the callback passed to its `ComputeAsync()` method
+ // was called).
+ void RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Called immediately after this executor finishes processing this node.
+ void RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ void SetOutput(int slot, const Tensor* v);
+
+ // Records information about the memory allocated during the execution of this
+ // node.
+ void SetMemory(OpKernelContext* ctx);
+
+ // Records information about the tensors that were accessed during the
+ // execution of this node.
+ void SetReferencedTensors(const TensorReferenceVector& tensors);
+
+ // Sets the timeline_label field of the wrapped NodeExecStats, using data
+ // from *node. Returns true iff the node is a transfer node.
+ bool SetTimelineLabel(const Node* node);
private:
friend class StepStatsCollector;
+ NodeExecStats* stats() { return stats_.get(); }
+
// Populates stats_ and releases TrackingAllocator.
void Finalize();
+ // Does not take ownership of the `allocator`.
+ // Takes ownership of `tracking_allocator`.
+ void AddAllocation(Allocator* allocator,
+ TrackingAllocator* tracking_allocator);
+
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
index 550f193332..cc5909de17 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building TensorFlow with SYCL support
#endif
-#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
@@ -72,4 +72,4 @@ class SYCLAllocator : public Allocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b2192c5a80..37029f3f1a 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -562,6 +562,7 @@ cc_library(
deps = [
":worker_cache",
":worker_interface",
+ "//tensorflow/core:framework",
],
)
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index a48f734d3e..269f620e42 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -167,13 +168,55 @@ class DeviceFinder {
}
// Enumerates all known workers' target. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
- CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
- const string& local_device_name = env_->local_devices[0]->name();
- std::vector<string> workers;
- worker_cache->ListWorkers(&workers);
if (filters_.empty()) {
+ // If no filters were specified, we list all known workers in
+ // `worker_cache`.
+ std::vector<string> workers;
+ worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
+ // When applying filters, we must include the local worker, even if it
+ // does not match any of the filters.
+ CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
+ const string& local_device_name = env_->local_devices[0]->name();
+ DeviceNameUtils::ParsedName local_parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(local_device_name,
+ &local_parsed_name));
+ bool all_filters_have_job = true;
+ std::unordered_set<string> filter_job_names({local_parsed_name.job});
+ for (const DeviceNameUtils::ParsedName& filter : filters_) {
+ all_filters_have_job = all_filters_have_job && filter.has_job;
+ if (filter.has_job) {
+ filter_job_names.insert(filter.job);
+ }
+ }
+
+ std::vector<string> workers;
+ if (all_filters_have_job) {
+ // If all of the device filters have a job specified, then we only need
+ // to list the workers in the jobs named in the filter, because a worker
+ // in any other job would not match any filter.
+ for (const string& job_name : filter_job_names) {
+ VLOG(2) << "Selectively listing workers in job: " << job_name;
+ std::vector<string> workers_in_job;
+ worker_cache->ListWorkersInJob(job_name, &workers_in_job);
+ workers.insert(workers.end(), workers_in_job.begin(),
+ workers_in_job.end());
+ }
+ } else {
+ // If any of the device filters does not have a job specified, then we
+ // must list the workers from all jobs.
+ VLOG(2) << "Listing workers in all jobs because some device "
+ << "filter has no job specified. Filters were:";
+ if (device_filters.empty()) {
+ VLOG(2) << "- <NO FILTERS>";
+ } else {
+ for (const string& filter : device_filters) {
+ VLOG(2) << "- " << filter;
+ }
+ }
+ worker_cache->ListWorkers(&workers);
+ }
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index b7eb3c9015..456c30ecf4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ for (GrpcChannelCache* cache : caches_) {
+ cache->ListWorkersInJob(job_name, workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ if (job_name == job_id_) {
+ ListWorkers(workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 4861cdb691..6fa99d7b14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -66,6 +66,8 @@ class GrpcChannelCache {
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index f07a5a0974..a814ef85e2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- EXPECT_EQ(std::vector<string>(
- {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
- "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- std::sort(workers.begin(), workers.end());
- EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
- "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ std::sort(workers.begin(), workers.end());
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index b9f21ea211..e1541db69b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
channel_cache_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ channel_cache_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 25ff6512a0..b070dd13dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -50,6 +50,8 @@ namespace {
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {}
WorkerInterface* CreateWorker(const string& target) override {
return nullptr;
}
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd..88a97da34d 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ DeviceNameUtils::ParsedName device_name;
+ CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+ CHECK(device_name.has_job);
+ if (job_name == device_name.job) {
+ workers->push_back(it.first);
+ }
+ }
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 8521f8956b..0c8575b4d5 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -36,6 +36,8 @@ class WorkerCacheInterface {
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 43c3b6285b..1f309b4361 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
virtual void ListWorkers(std::vector<string>* workers) const {
return wrapped_->ListWorkers(workers);
}
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const {
+ return wrapped_->ListWorkersInJob(job_name, workers);
+ }
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca6dc1b1de..c7d0c6b7f3 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
wrapped_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ wrapped_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 9be0dc69d2..3597f43d51 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -172,6 +172,15 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
return nullptr;
}
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
#define VALIDATE(EXPR, ...) \
do { \
if (!(EXPR)) { \
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index 0ba1325a03..4f67a258d3 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
#include <string>
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -47,6 +48,10 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
// Returns nullptr if no such attr is found.
const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def);
+// Searches api_def for input argument with the indicated name.
+// Returns nullptr if no such attr is found.
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def);
+
// Produce a human-readable version of an op_def that is more concise
// than a text-format proto. Excludes descriptions.
string SummarizeOpDef(const OpDef& op_def);
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index f1cd37ecda..6f6b7cec3e 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -161,9 +161,12 @@ limitations under the License.
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_FLOAT_TYPES(m) \
+ TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+
#define TF_CALL_REAL_NUMBER_TYPES(m) \
TF_CALL_INTEGRAL_TYPES(m) \
- TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+ TF_CALL_FLOAT_TYPES(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 33d4cb77ff..976fede148 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -61,8 +61,8 @@ namespace tensorflow {
//
// // Create a var.
// MyVar* my_var = new MyVar;
-// my_var.val = Tensor(DT_FLOAT, my_shape);
-// my_var.val.flat<float>().setZeros(); // 0 initialized.
+// my_var->val = Tensor(DT_FLOAT, my_shape);
+// my_var->val.flat<float>().setZeros(); // 0 initialized.
// ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
// // += a variable.
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 5683944e46..833592caab 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2418,6 +2418,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.conv2d_grad_filter_with_bias =
"__MklDummyConv2DBackpropFilterWithBias";
+ csinfo_.conv3d = "Conv3D";
+ csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
+ csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.identity = "Identity";
@@ -2468,18 +2471,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
- csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+ csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_filter,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_input,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
@@ -2614,6 +2626,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv2d_grad_input;
string conv2d_grad_filter;
string conv2d_grad_filter_with_bias;
+ string conv3d;
+ string conv3d_grad_input;
+ string conv3d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
string identity;
@@ -3086,7 +3101,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
@@ -3571,14 +3586,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
std::vector<int32> strides;
std::vector<int32> dilations;
- bool use_cudnn_on_gpu;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
@@ -3586,8 +3600,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@@ -3595,7 +3607,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("dilations", dilations);
nb->Attr("padding", padding);
nb->Attr("data_format", data_format);
- nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
@@ -3896,7 +3907,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
// Copy attributes from Conv2D to Conv2DWithBias.
- CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
+ CopyAttrsConv(const_cast<const Node*>(pred), &nb);
// Copy the device assigned to old node to new node.
nb.Device(succ->def().device());
@@ -4007,7 +4018,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
}
// Copy attributes from Conv2DBackpropFilter.
- CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb);
+ CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
// Copy the device assigned to old node to new node.
nb.Device(fltr->def().device());
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 46c234d057..e40347fcf4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -495,16 +495,6 @@ cc_library(
],
)
-cc_library(
- name = "warn_about_ints",
- srcs = ["warn_about_ints.cc"],
- hdrs = ["warn_about_ints.h"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
# Private support libraries ---------------------------------------------------
cc_header_only_library(
@@ -1308,6 +1298,7 @@ tf_cuda_cc_test(
srcs = ["gather_nd_op_test.cc"],
deps = [
":gather_nd_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -2884,7 +2875,7 @@ tf_kernel_library(
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "batch_matmul_op",
deps = MATH_DEPS + if_mkl_ml([
- "//third_party/intel_mkl_ml",
+ "//third_party/mkl:intel_binary_blob",
]),
)
@@ -3176,6 +3167,7 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":reduction_ops",
@@ -3311,6 +3303,7 @@ tf_cuda_cc_test(
srcs = ["diag_op_test.cc"],
deps = [
":diag_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3531,13 +3524,13 @@ tf_kernel_library(
tf_kernel_library(
name = "softplus_op",
prefix = "softplus_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
name = "softsign_op",
prefix = "softsign_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
@@ -3638,6 +3631,7 @@ tf_cuda_cc_test(
name = "nn_ops_test",
srcs = ["nn_ops_test.cc"],
deps = [
+ ":host_constant_op",
":nn",
":ops_testutil",
":ops_util",
@@ -3785,6 +3779,7 @@ tf_cuda_cc_test(
srcs = ["spacetobatch_benchmark_test.cc"],
deps = [
":batch_space_ops",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3924,6 +3919,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["random_op_test.cc"],
deps = [
+ ":host_constant_op",
":random_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -4178,6 +4174,7 @@ tf_cuda_cc_tests(
"sparse_xent_op_test.cc",
],
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":sparse",
@@ -4391,6 +4388,7 @@ cc_library(
":regex_full_match_op",
":regex_replace_op",
":string_join_op",
+ ":string_length_op",
":string_split_op",
":string_strip_op",
":string_to_hash_bucket_op",
@@ -4426,6 +4424,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_length_op",
+ prefix = "string_length_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
name = "regex_full_match_op",
prefix = "regex_full_match_op",
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
@@ -4437,12 +4441,48 @@ tf_kernel_library(
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
)
+tf_cc_test(
+ name = "regex_replace_op_test",
+ size = "small",
+ srcs = ["regex_replace_op_test.cc"],
+ deps = [
+ ":regex_replace_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_split_op",
prefix = "string_split_op",
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "string_split_op_test",
+ size = "small",
+ srcs = ["string_split_op_test.cc"],
+ deps = [
+ ":string_split_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_strip_op",
prefix = "string_strip_op",
@@ -4516,6 +4556,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["multinomial_op_test.cc"],
deps = [
+ ":host_constant_op",
":multinomial_op",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -4543,6 +4584,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["parameterized_truncated_normal_op_test.cc"],
deps = [
+ ":host_constant_op",
":ops_util",
":parameterized_truncated_normal_op",
"//tensorflow/core:core_cpu",
@@ -5052,7 +5094,6 @@ filegroup(
"training_ops.h",
"transpose_functor.h",
"transpose_op.h",
- "warn_about_ints.h",
"where_op.h",
"xent_op.h",
],
@@ -5229,7 +5270,6 @@ filegroup(
"transpose_functor_cpu.cc",
"transpose_op.cc",
"unique_op.cc",
- "warn_about_ints.cc",
"where_op.cc",
"xent_op.cc",
":android_extended_ops_headers",
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 475bda848d..766713a338 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -15,6 +15,9 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -613,3 +616,5 @@ class BatchMatMul : public OpKernel {
BatchMatMul<SYCLDevice, TYPE>)
#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
new file mode 100644
index 0000000000..3163c63949
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -0,0 +1,63 @@
+# Description:
+# This directory contains common utilities used in boosted_trees.
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# Quantiles
+
+cc_library(
+ name = "weighted_quantiles",
+ srcs = [],
+ hdrs = [
+ "weighted_quantiles_buffer.h",
+ "weighted_quantiles_stream.h",
+ "weighted_quantiles_summary.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_buffer_test",
+ size = "small",
+ srcs = ["weighted_quantiles_buffer_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_summary_test",
+ size = "small",
+ srcs = ["weighted_quantiles_summary_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_stream_test",
+ size = "small",
+ srcs = ["weighted_quantiles_stream_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
new file mode 100644
index 0000000000..07aa9831c4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
@@ -0,0 +1,132 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Buffering container ideally suited for scenarios where we need
+// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesBuffer {
+ public:
+ struct BufferEntry {
+ BufferEntry(ValueType v, WeightType w)
+ : value(std::move(v)), weight(std::move(w)) {}
+ BufferEntry() : value(), weight(0) {}
+
+ bool operator<(const BufferEntry& other) const {
+ return kCompFn(value, other.value);
+ }
+ bool operator==(const BufferEntry& other) const {
+ return value == other.value && weight == other.weight;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const BufferEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << "}";
+ }
+ ValueType value;
+ WeightType weight;
+ };
+
+ explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
+ : max_size_(std::min(block_size << 1, max_elements)) {
+ QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
+ << ", " << max_elements << ")";
+ vec_.reserve(max_size_);
+ }
+
+ // Disallow copying as it's semantically non-sensical in the Squawd algorithm
+ // but enable move semantics.
+ WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
+ WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
+ WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
+ WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
+
+ // Push entry to buffer and maintain a compact representation within
+ // pre-defined size limit.
+ void PushEntry(ValueType value, WeightType weight) {
+ // Callers are expected to act on a full compacted buffer after the
+ // PushEntry call returns.
+ QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
+
+ // Ignore zero and negative weight entries.
+ if (weight <= 0) {
+ return;
+ }
+
+ // Push back the entry to the buffer.
+ vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
+ }
+
+ // Returns a sorted vector view of the base buffer and clears the buffer.
+ // Callers should minimize how often this is called, ideally only right after
+ // the buffer becomes full.
+ std::vector<BufferEntry> GenerateEntryList() {
+ std::vector<BufferEntry> ret;
+ if (vec_.size() == 0) {
+ return ret;
+ }
+ ret.swap(vec_);
+ vec_.reserve(max_size_);
+ std::sort(ret.begin(), ret.end());
+ size_t num_entries = 0;
+ for (size_t i = 1; i < ret.size(); ++i) {
+ if (ret[i].value != ret[i - 1].value) {
+ BufferEntry tmp = ret[i];
+ ++num_entries;
+ ret[num_entries] = tmp;
+ } else {
+ ret[num_entries].weight += ret[i].weight;
+ }
+ }
+ ret.resize(num_entries + 1);
+ return ret;
+ }
+
+ int64 Size() const { return vec_.size(); }
+ bool IsFull() const { return vec_.size() >= max_size_; }
+ void Clear() { vec_.clear(); }
+
+ private:
+ using BufferVector = typename std::vector<BufferEntry>;
+
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Base buffer.
+ size_t max_size_;
+ BufferVector vec_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
new file mode 100644
index 0000000000..75f05d64f3
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
@@ -0,0 +1,99 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double,
+ double>::BufferEntry;
+
+class WeightedQuantilesBufferTest : public ::testing::Test {};
+
+TEST_F(WeightedQuantilesBufferTest, Invalid) {
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(2, 0);
+ }),
+ "Invalid buffer specification");
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(0, 2);
+ }),
+ "Invalid buffer specification");
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) {
+ Buffer buffer(20, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(3, 0); // This entry will be ignored.
+
+ EXPECT_FALSE(buffer.IsFull());
+ EXPECT_EQ(buffer.Size(), 3);
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ EXPECT_EQ(buffer.GenerateEntryList(), expected);
+ EXPECT_FALSE(buffer.IsFull());
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ // Can't push any more entries before clearing.
+ EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
new file mode 100644
index 0000000000..525e2a6a64
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
@@ -0,0 +1,330 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Class to compute approximate quantiles with error bound guarantees for
+// weighted data sets.
+// This implementation is an adaptation of techniques from the following papers:
+// * (2001) Space-efficient online computation of quantile summaries.
+// * (2004) Power-conserving computation of order-statistics over
+// sensor networks.
+// * (2007) A fast algorithm for approximate quantiles in high speed
+// data streams.
+// * (2016) XGBoost: A Scalable Tree Boosting System.
+//
+// The key ideas at play are the following:
+// - Maintain an in-memory multi-level quantile summary in a way to guarantee
+// a maximum approximation error of eps * W per bucket where W is the total
+// weight across all points in the input dataset.
+// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two
+// summaries guaranteeing a epsNew = max(eps1, eps2). COMPRESS compresses
+// a summary to b + 1 elements guaranteeing epsNew = epsOld + 1/b.
+// - b * sizeof(summary entry) must ideally be small enough to fit in an
+// average CPU L2 cache.
+// - To distribute this algorithm with maintaining error bounds, we need
+// the worker-computed summaries to have no more than eps / h error
+// where h is the height of the distributed computation graph which
+// is 2 for an MR with no combiner.
+//
+// We mainly want to max out IO bw by ensuring we're not compute-bound and
+// using a reasonable amount of RAM.
+//
+// Complexity:
+// Compute: O(n * log(1/eps * log(eps * n))).
+// Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the
+// entire dataset.
+// An epsilon value of zero would make the algorithm extremely inefficent and
+// therefore, is disallowed.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesStream {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+ using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>;
+ using SummaryEntry = typename Summary::SummaryEntry;
+
+ explicit WeightedQuantilesStream(double eps, int64 max_elements)
+ : eps_(eps), buffer_(1LL, 2LL), finalized_(false) {
+ // See the class documentation. An epsilon value of zero could cause
+ // perfoamance issues.
+ QCHECK(eps > 0) << "An epsilon value of zero is not allowed.";
+ std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements);
+ buffer_ = Buffer(block_size_, max_elements);
+ summary_levels_.reserve(max_levels_);
+ }
+
+ // Disallow copy and assign but enable move semantics for the stream.
+ WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete;
+ WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete;
+ WeightedQuantilesStream(WeightedQuantilesStream&& other) = default;
+ WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default;
+
+ // Pushes one entry while maintaining approximation error invariants.
+ void PushEntry(const ValueType& value, const WeightType& weight) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Push element to base buffer.
+ buffer_.PushEntry(value, weight);
+
+ // When compacted buffer is full we need to compress
+ // and push weighted quantile summary up the level chain.
+ if (buffer_.IsFull()) {
+ PushBuffer(buffer_);
+ }
+ }
+
+ // Pushes full buffer while maintaining approximation error invariants.
+ void PushBuffer(Buffer& buffer) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList());
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Pushes full summary while maintaining approximation error invariants.
+ void PushSummary(const std::vector<SummaryEntry>& summary) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromSummaryEntries(summary);
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Flushes approximator and finalizes state.
+ void Finalize() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() may only be called once.";
+
+ // Flush any remaining buffer elements.
+ PushBuffer(buffer_);
+
+ // Create final merged summary.
+ local_summary_.Clear();
+ for (auto& summary : summary_levels_) {
+ local_summary_.Merge(summary);
+ summary.Clear();
+ }
+ summary_levels_.clear();
+ summary_levels_.shrink_to_fit();
+ finalized_ = true;
+ }
+
+ // Generates requested number of quantiles after finalizing stream.
+ // The returned quantiles can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating quantiles.";
+ return local_summary_.GenerateQuantiles(num_quantiles);
+ }
+
+ // Generates requested number of boundaries after finalizing stream.
+ // The returned boundaries can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ // The boundaries, while still guaranteeing approximation bounds, don't
+ // necessarily represent the actual quantiles of the distribution.
+ // Boundaries are preferable over quantiles when the caller is less
+ // interested in the actual quantiles distribution and more interested in
+ // getting a representative sample of boundary values.
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating boundaries.";
+ return local_summary_.GenerateBoundaries(num_boundaries);
+ }
+
+ // Calculates approximation error for the specified level.
+ // If the passed level is negative, the approximation error for the entire
+ // summary is returned. Note that after Finalize is called, only the overall
+ // error is available.
+ WeightType ApproximationError(int64 level = -1) const {
+ if (finalized_) {
+ QCHECK(level <= 0) << "Only overall error is available after Finalize()";
+ return local_summary_.ApproximationError();
+ }
+
+ if (summary_levels_.empty()) {
+ // No error even if base buffer isn't empty.
+ return 0;
+ }
+
+ // If level is negative, we get the approximation error
+ // for the top-most level which is the max approximation error
+ // in all summaries by construction.
+ if (level < 0) {
+ level = summary_levels_.size() - 1;
+ }
+ QCHECK(level < summary_levels_.size()) << "Invalid level.";
+ return summary_levels_[level].ApproximationError();
+ }
+
+ size_t MaxDepth() const { return summary_levels_.size(); }
+
+ // Generates requested number of quantiles after finalizing stream.
+ const Summary& GetFinalSummary() const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before requesting final summary.";
+ return local_summary_;
+ }
+
+ // Helper method which, given the desired approximation error
+ // and an upper bound on the number of elements, computes the optimal
+ // number of levels and block size and returns them in the tuple.
+ static std::tuple<int64, int64> GetQuantileSpecs(double eps,
+ int64 max_elements);
+
+ // Serializes the internal state of the stream.
+ std::vector<Summary> SerializeInternalSummaries() const {
+ // The buffer should be empty for serialize to work.
+ QCHECK_EQ(buffer_.Size(), 0);
+ std::vector<Summary> result;
+ result.reserve(summary_levels_.size() + 1);
+ for (const Summary& summary : summary_levels_) {
+ result.push_back(summary);
+ }
+ result.push_back(local_summary_);
+ return result;
+ }
+
+ // Resets the state of the stream with a serialized state.
+ void DeserializeInternalSummaries(const std::vector<Summary>& summaries) {
+ // Clear the state before deserializing.
+ buffer_.Clear();
+ summary_levels_.clear();
+ local_summary_.Clear();
+ QCHECK_GT(max_levels_, summaries.size() - 1);
+ for (int i = 0; i < summaries.size() - 1; ++i) {
+ summary_levels_.push_back(summaries[i]);
+ }
+ local_summary_ = summaries[summaries.size() - 1];
+ }
+
+ private:
+ // Propagates local summary through summary levels while maintaining
+ // approximation error invariants.
+ void PropagateLocalSummary() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // No-op if there's nothing to add.
+ if (local_summary_.Size() <= 0) {
+ return;
+ }
+
+ // Propagate summary through levels.
+ size_t level = 0;
+ for (bool settled = false; !settled; ++level) {
+ // Ensure we have enough depth.
+ if (summary_levels_.size() <= level) {
+ summary_levels_.emplace_back();
+ }
+
+ // Merge summaries.
+ Summary& current_summary = summary_levels_[level];
+ local_summary_.Merge(current_summary);
+
+ // Check if we need to compress and propagate summary higher.
+ if (current_summary.Size() == 0 ||
+ local_summary_.Size() <= block_size_ + 1) {
+ current_summary = std::move(local_summary_);
+ settled = true;
+ } else {
+ // Compress, empty current level and propagate.
+ local_summary_.Compress(block_size_, eps_);
+ current_summary.Clear();
+ }
+ }
+ }
+
+ // Desired approximation precision.
+ double eps_;
+ // Maximum number of levels.
+ int64 max_levels_;
+ // Max block size per level.
+ int64 block_size_;
+ // Base buffer.
+ Buffer buffer_;
+ // Local summary used to minimize memory allocation and cache misses.
+ // After the stream is finalized, this summary holds the final quantile
+ // estimates.
+ Summary local_summary_;
+ // Summary levels;
+ std::vector<Summary> summary_levels_;
+ // Flag indicating whether the stream is finalized.
+ bool finalized_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+inline std::tuple<int64, int64>
+WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
+ double eps, int64 max_elements) {
+ int64 max_level = 1LL;
+ int64 block_size = 2LL;
+ QCHECK(eps >= 0 && eps < 1);
+ QCHECK_GT(max_elements, 0);
+
+ if (eps <= std::numeric_limits<double>::epsilon()) {
+ // Exact quantile computation at the expense of RAM.
+ max_level = 1;
+ block_size = std::max(max_elements, int64{2});
+ } else {
+ // The bottom-most level will become full at most
+ // (max_elements / block_size) times, the level above will become full
+ // (max_elements / 2 * block_size) times and generally level l becomes
+ // full (max_elements / 2^l * block_size) times until the last
+ // level max_level becomes full at most once meaning when the inequality
+ // (2^max_level * block_size >= max_elements) is satisfied.
+ // In what follows, we jointly solve for max_level and block_size by
+ // gradually increasing the level until the inequality above is satisfied.
+ // We could alternatively set max_level = ceil(log2(eps * max_elements));
+ // and block_size = ceil(max_level / eps) + 1 but that tends to give more
+ // pessimistic bounds and wastes RAM needlessly.
+ for (max_level = 1, block_size = 2;
+ (1LL << max_level) * block_size < max_elements; ++max_level) {
+ // Update upper bound on block size at current level, we always
+ // increase the estimate by 2 to hold the min/max elements seen so far.
+ block_size = static_cast<size_t>(ceil(max_level / eps)) + 1;
+ }
+ }
+ return std::make_tuple(max_level, std::max(block_size, int64{2}));
+}
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
new file mode 100644
index 0000000000..6c5b9fd23b
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
@@ -0,0 +1,276 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+using Tuple = std::tuple<int64, int64>;
+
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double, double>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double,
+ double>::SummaryEntry;
+using Stream =
+ boosted_trees::quantiles::WeightedQuantilesStream<double, double>;
+
+TEST(GetQuantileSpecs, InvalidEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(-0.01, 0L); }, "eps >= 0");
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(1.01, 0L); }, "eps < 1");
+}
+
+TEST(GetQuantileSpecs, ZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.0, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 1LL), Tuple(1LL, 2LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 20LL), Tuple(1LL, 20LL));
+}
+
+TEST(GetQuantileSpecs, NonZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.01, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 320LL), Tuple(4LL, 31LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 25600LL), Tuple(6LL, 501LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 104857600LL), Tuple(17LL, 1601LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 104857600LL), Tuple(20LL, 191LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 1LL << 40), Tuple(29LL, 2801LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.001, 1LL << 40), Tuple(26LL, 25001LL));
+}
+
+class WeightedQuantilesStreamTest : public ::testing::Test {};
+
+// Stream generators.
+void GenerateFixedUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, 1.0);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateFixedNonUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, x);
+ (*total_weight) += x;
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformFixedWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ stream->PushEntry(x, 1);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformRandWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ const double w = rand.RandDouble();
+ stream->PushEntry(x, w);
+ (*total_weight) += w;
+ }
+ stream->Finalize();
+}
+
+// Single worker tests.
+void TestSingleWorkerStreams(
+ double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Generate single stream.
+ double total_weight = 0;
+ Stream stream(eps, max_elements);
+ worker_summary_generator(0, max_elements, &total_weight, &stream);
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(stream.ApproximationError(), eps);
+ EXPECT_NEAR(stream.GetFinalSummary().TotalWeight(), total_weight, 1e-6);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals = stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+// Stream generators.
+void GenerateOneValue(int32 worker_id, int64 max_elements, double *total_weight,
+ Stream *stream) {
+ stream->PushEntry(10, 1);
+ ++(*total_weight);
+ stream->Finalize();
+}
+
+void GenerateOneZeroWeightedValue(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ stream->PushEntry(10, 0);
+ stream->Finalize();
+}
+
+TEST(WeightedQuantilesStreamTest, OneValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneValue,
+ {10.0, 10.0, 10.0, 10.0, 10.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, OneZeroWeightValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneZeroWeightedValue, {},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+// Distributed tests.
+void TestDistributedStreams(
+ int32 num_workers, double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Simulate streams on each worker running independently
+ double total_weight = 0;
+ std::vector<std::vector<SummaryEntry>> worker_summaries;
+ for (int32 i = 0; i < num_workers; ++i) {
+ Stream stream(eps / 2, max_elements);
+ worker_summary_generator(i, max_elements / num_workers, &total_weight,
+ &stream);
+ worker_summaries.push_back(stream.GetFinalSummary().GetEntryList());
+ }
+
+ // In the accumulation phase, we aggregate the summaries from each worker
+ // and build an overall summary while maintaining error bounds by ensuring we
+ // don't increase the error by more than eps / 2.
+ Stream reducer_stream(eps, max_elements);
+ for (const auto &summary : worker_summaries) {
+ reducer_stream.PushSummary(summary);
+ }
+ reducer_stream.Finalize();
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(reducer_stream.ApproximationError(), eps);
+ EXPECT_NEAR(reducer_stream.GetFinalSummary().TotalWeight(), total_weight,
+ total_weight);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals =
+ reducer_stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(num_workers, eps, max_elements,
+ GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
new file mode 100644
index 0000000000..31d7fe25a4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
@@ -0,0 +1,344 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+
+#include <cstring>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Summary holding a sorted block of entries with upper bound guarantees
+// over the approximation error.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesSummary {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+
+ struct SummaryEntry {
+ SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
+ const WeightType& max) {
+ // Explicitly initialize all of memory (including padding from memory
+ // alignment) to allow the struct to be msan-resistant "plain old data".
+ //
+ // POD = http://en.cppreference.com/w/cpp/concept/PODType
+ memset(this, 0, sizeof(*this));
+
+ value = v;
+ weight = w;
+ min_rank = min;
+ max_rank = max;
+ }
+
+ SummaryEntry() {
+ memset(this, 0, sizeof(*this));
+
+ value = ValueType();
+ weight = 0;
+ min_rank = 0;
+ max_rank = 0;
+ }
+
+ bool operator==(const SummaryEntry& other) const {
+ return value == other.value && weight == other.weight &&
+ min_rank == other.min_rank && max_rank == other.max_rank;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const SummaryEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << ", "
+ << entry.min_rank << ", " << entry.max_rank << "}";
+ }
+
+ // Max rank estimate for previous smaller value.
+ WeightType PrevMaxRank() const { return max_rank - weight; }
+
+ // Min rank estimate for next larger value.
+ WeightType NextMinRank() const { return min_rank + weight; }
+
+ ValueType value;
+ WeightType weight;
+ WeightType min_rank;
+ WeightType max_rank;
+ };
+
+ // Re-construct summary from the specified buffer.
+ void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
+ entries_.clear();
+ entries_.reserve(buffer_entries.size());
+ WeightType cumulative_weight = 0;
+ for (const auto& entry : buffer_entries) {
+ WeightType current_weight = entry.weight;
+ entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
+ cumulative_weight + current_weight);
+ cumulative_weight += current_weight;
+ }
+ }
+
+ // Re-construct summary from the specified summary entries.
+ void BuildFromSummaryEntries(
+ const std::vector<SummaryEntry>& summary_entries) {
+ entries_.clear();
+ entries_.reserve(summary_entries.size());
+ entries_.insert(entries_.begin(), summary_entries.begin(),
+ summary_entries.end());
+ }
+
+ // Merges two summaries through an algorithm that's derived from MergeSort
+ // for summary entries while guaranteeing that the max approximation error
+ // of the final merged summary is no greater than the approximation errors
+ // of each individual summary.
+ // For example consider summaries where each entry is of the form
+ // (element, weight, min rank, max rank):
+ // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
+ // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
+ // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
+ void Merge(const WeightedQuantilesSummary& other_summary) {
+ // Make sure we have something to merge.
+ const auto& other_entries = other_summary.entries_;
+ if (other_entries.empty()) {
+ return;
+ }
+ if (entries_.empty()) {
+ BuildFromSummaryEntries(other_summary.entries_);
+ return;
+ }
+
+ // Move current entries to make room for a new buffer.
+ std::vector<SummaryEntry> base_entries(std::move(entries_));
+ entries_.clear();
+ entries_.reserve(base_entries.size() + other_entries.size());
+
+ // Merge entries maintaining ranks. The idea is to stack values
+ // in order which we can do in linear time as the two summaries are
+ // already sorted. We keep track of the next lower rank from either
+ // summary and update it as we pop elements from the summaries.
+ // We handle the special case when the next two elements from either
+ // summary are equal, in which case we just merge the two elements
+ // and simultaneously update both ranks.
+ auto it1 = base_entries.cbegin();
+ auto it2 = other_entries.cbegin();
+ WeightType next_min_rank1 = 0;
+ WeightType next_min_rank2 = 0;
+ while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
+ if (kCompFn(it1->value, it2->value)) { // value1 < value2
+ // Take value1 and use the last added value2 to compute
+ // the min rank and the current value2 to compute the max rank.
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + it2->PrevMaxRank());
+ // Update next min rank 1.
+ next_min_rank1 = it1->NextMinRank();
+ ++it1;
+ } else if (kCompFn(it2->value, it1->value)) { // value1 > value2
+ // Take value2 and use the last added value1 to compute
+ // the min rank and the current value1 to compute the max rank.
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + it1->PrevMaxRank());
+ // Update next min rank 2.
+ next_min_rank2 = it2->NextMinRank();
+ ++it2;
+ } else { // value1 == value2
+ // Straight additive merger of the two entries into one.
+ entries_.emplace_back(it1->value, it1->weight + it2->weight,
+ it1->min_rank + it2->min_rank,
+ it1->max_rank + it2->max_rank);
+ // Update next min ranks for both.
+ next_min_rank1 = it1->NextMinRank();
+ next_min_rank2 = it2->NextMinRank();
+ ++it1;
+ ++it2;
+ }
+ }
+
+ // Fill in any residual.
+ while (it1 != base_entries.cend()) {
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + other_entries.back().max_rank);
+ ++it1;
+ }
+ while (it2 != other_entries.cend()) {
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + base_entries.back().max_rank);
+ ++it2;
+ }
+ }
+
+ // Compresses buffer into desired size. The size specification is
+ // considered a hint as we always keep the first and last elements and
+ // maintain strict approximation error bounds.
+ // The approximation error delta is taken as the max of either the requested
+ // min error or 1 / size_hint.
+ // After compression, the approximation error is guaranteed to increase
+ // by no more than that error delta.
+ // This algorithm is linear in the original size of the summary and is
+ // designed to be cache-friendly.
+ void Compress(int64 size_hint, double min_eps = 0) {
+ // No-op if we're already within the size requirement.
+ size_hint = std::max(size_hint, int64{2});
+ if (entries_.size() <= size_hint) {
+ return;
+ }
+
+ // First compute the max error bound delta resulting from this compression.
+ double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);
+
+ // Compress elements ensuring approximation bounds and elements diversity
+ // are both maintained.
+ int64 add_accumulator = 0, add_step = entries_.size();
+ auto write_it = entries_.begin() + 1, last_it = write_it;
+ for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
+ auto next_it = read_it + 1;
+ while (next_it != entries_.end() && add_accumulator < add_step &&
+ next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
+ add_accumulator += size_hint;
+ ++next_it;
+ }
+ if (read_it == next_it - 1) {
+ ++read_it;
+ } else {
+ read_it = next_it - 1;
+ }
+ (*write_it++) = (*read_it);
+ last_it = read_it;
+ add_accumulator -= add_step;
+ }
+ // Write last element and resize.
+ if (last_it + 1 != entries_.end()) {
+ (*write_it++) = entries_.back();
+ }
+ entries_.resize(write_it - entries_.begin());
+ }
+
+ // To construct the boundaries we first run a soft compress over a copy
+ // of the summary and retrieve the values.
+ // The resulting boundaries are guaranteed to both contain at least
+ // num_boundaries unique elements and maintain approximation bounds.
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+
+ // Generate soft compressed summary.
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
+ compressed_summary;
+ compressed_summary.BuildFromSummaryEntries(entries_);
+ // Set an epsilon for compression that's at most 1.0 / num_boundaries
+ // more than epsilon of original our summary since the compression operation
+ // adds ~1.0/num_boundaries to final approximation error.
+ float compression_eps = ApproximationError() + (1.0 / num_boundaries);
+ compressed_summary.Compress(num_boundaries, compression_eps);
+
+ // Return boundaries.
+ output.reserve(compressed_summary.entries_.size());
+ for (const auto& entry : compressed_summary.entries_) {
+ output.push_back(entry.value);
+ }
+ return output;
+ }
+
+ // To construct the desired n-quantiles we repetitively query n ranks from the
+ // original summary. The following algorithm is an efficient cache-friendly
+ // O(n) implementation of that idea which avoids the cost of the repetitive
+ // full rank queries O(nlogn).
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+ num_quantiles = std::max(num_quantiles, int64{2});
+ output.reserve(num_quantiles + 1);
+
+ // Make successive rank queries to get boundaries.
+ // We always keep the first (min) and last (max) entries.
+ for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
+ // This step boils down to finding the next element sub-range defined by
+ // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
+ WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
+ size_t next_idx = cur_idx + 1;
+ while (next_idx < entries_.size() &&
+ d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
+ ++next_idx;
+ }
+ cur_idx = next_idx - 1;
+
+ // Determine insertion order.
+ if (next_idx == entries_.size() ||
+ d_2 < entries_[cur_idx].NextMinRank() +
+ entries_[next_idx].PrevMaxRank()) {
+ output.push_back(entries_[cur_idx].value);
+ } else {
+ output.push_back(entries_[next_idx].value);
+ }
+ }
+ return output;
+ }
+
+ // Calculates current approximation error which should always be <= eps.
+ double ApproximationError() const {
+ if (entries_.empty()) {
+ return 0;
+ }
+
+ WeightType max_gap = 0;
+ for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
+ max_gap = std::max(max_gap,
+ std::max(it->max_rank - it->min_rank - it->weight,
+ it->PrevMaxRank() - (it - 1)->NextMinRank()));
+ }
+ return static_cast<double>(max_gap) / TotalWeight();
+ }
+
+ ValueType MinValue() const {
+ return !entries_.empty() ? entries_.front().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ ValueType MaxValue() const {
+ return !entries_.empty() ? entries_.back().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ WeightType TotalWeight() const {
+ return !entries_.empty() ? entries_.back().max_rank : 0;
+ }
+ int64 Size() const { return entries_.size(); }
+ void Clear() { entries_.clear(); }
+ const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
+
+ private:
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Summary entries.
+ std::vector<SummaryEntry> entries_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
new file mode 100644
index 0000000000..ccd1215cf4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
@@ -0,0 +1,223 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer = boosted_trees::quantiles::WeightedQuantilesBuffer<float, float>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<float,
+ float>::BufferEntry;
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+class WeightedQuantilesSummaryTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Constructs a buffer of 10 weighted unique entries.
+ buffer1_.reset(new Buffer(10, 1000));
+ buffer1_->PushEntry(5, 9);
+ buffer1_->PushEntry(2, 3);
+ buffer1_->PushEntry(-1, 7);
+ buffer1_->PushEntry(-7, 1);
+ buffer1_->PushEntry(3, 2);
+ buffer1_->PushEntry(-2, 3);
+ buffer1_->PushEntry(21, 8);
+ buffer1_->PushEntry(-13, 4);
+ buffer1_->PushEntry(8, 2);
+ buffer1_->PushEntry(-5, 6);
+
+ // Constructs a buffer of 7 weighted unique entries.
+ buffer2_.reset(new Buffer(7, 1000));
+ buffer2_->PushEntry(9, 2);
+ buffer2_->PushEntry(-7, 3);
+ buffer2_->PushEntry(2, 1);
+ buffer2_->PushEntry(4, 13);
+ buffer2_->PushEntry(0, 5);
+ buffer2_->PushEntry(-5, 3);
+ buffer2_->PushEntry(11, 3);
+ }
+
+ void TearDown() override { buffer1_->Clear(); }
+
+ std::unique_ptr<Buffer> buffer1_;
+ std::unique_ptr<Buffer> buffer2_;
+ const double buffer1_min_value_ = -13;
+ const double buffer1_max_value_ = 21;
+ const double buffer1_total_weight_ = 45;
+ const double buffer2_min_value_ = -7;
+ const double buffer2_max_value_ = 11;
+ const double buffer2_total_weight_ = 30;
+};
+
+TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+
+ // We expect no approximation error because no compress operation occurred.
+ EXPECT_EQ(summary.ApproximationError(), 0);
+
+ // Check first and last elements in the summary.
+ const auto& entries = summary.GetEntryList();
+ // First element's rmin should be zero.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(entries.front(), SummaryEntry(-13, 4, 0, 4));
+ // Last element's rmax should be cumulative weight.
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(entries.back(), SummaryEntry(21, 8, 37, 45));
+ // Check total weight.
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) {
+ const auto entry_list = buffer1_->GenerateEntryList();
+ for (int new_size = 9; new_size >= 2; --new_size) {
+ Summary summary;
+ summary.BuildFromBufferEntries(entry_list);
+ summary.Compress(new_size);
+
+ // Expect a max approximation error of 1 / n
+ // ie. eps0 + 1/n but eps0 = 0.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSequentially) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ for (int new_size = 9; new_size >= 2; new_size -= 2) {
+ double prev_eps = summary.ApproximationError();
+ summary.Compress(new_size);
+
+ // Expect a max approximation error of prev_eps + 1 / n.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), prev_eps + 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) {
+ // Check multiple size compressions and ensure approximation bounds
+ // are always respected.
+ int prev_size = 1;
+ int size = 2;
+ float max_value = 1 << 20;
+ while (size < (1 << 16)) {
+ // Create buffer of size from uniform random elements.
+ Buffer buffer(size, size << 4);
+ random::PhiloxRandom philox(13);
+ random::SimplePhilox rand(&philox);
+ for (int i = 0; i < size; ++i) {
+ buffer.PushEntry(rand.RandFloat() * max_value,
+ rand.RandFloat() * max_value);
+ }
+
+ // Create summary and compress.
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer.GenerateEntryList());
+ int new_size = std::max(rand.Uniform(size), 2u);
+ summary.Compress(new_size);
+
+ // Ensure approximation error is acceptable.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+ // Update size to next fib number.
+ size_t last_size = size;
+ size += prev_size;
+ prev_size = last_size;
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
+ // Create two separate summaries and merge.
+ const auto list_1 = buffer1_->GenerateEntryList();
+ const auto list_2 = buffer2_->GenerateEntryList();
+ Summary summary1;
+ summary1.BuildFromBufferEntries(list_1);
+ Summary summary2;
+ summary2.BuildFromBufferEntries(list_2);
+
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_EQ(summary1.ApproximationError(), 0.0);
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary1.Size(), 14); // 14 unique values.
+
+ // Merge summary 1 into 2 and verify same result.
+ summary1.BuildFromBufferEntries(list_1);
+ summary2.Merge(summary1);
+ EXPECT_EQ(summary2.ApproximationError(), 0.0);
+ EXPECT_EQ(summary2.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary2.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary2.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary2.Size(), 14); // 14 unique values.
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressThenMerge) {
+ // Create two separate summaries and merge.
+ Summary summary1;
+ summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ Summary summary2;
+ summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList());
+
+ // Compress summaries.
+ summary1.Compress(5); // max error is 1/5.
+ const auto eps1 = 1.0 / 5;
+ EXPECT_LE(summary1.ApproximationError(), eps1);
+ summary2.Compress(3); // max error is 1/3.
+ const auto eps2 = 1.0 / 3;
+ EXPECT_LE(summary2.ApproximationError(), eps2);
+
+ // Merge guarantees an approximation error of max(eps1, eps2).
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_LE(summary1.ApproximationError(), std::max(eps1, eps2));
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 0478c93280..3a72567655 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
ctx->set_output(0, inp);
} else {
Tensor in;
- in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ if (external_src_dtype_ != src_dtype_) {
+ // If the type is a quantized type we need to do an UnsafeCopyFromInternal
+ // since the src_dtype_ is different from external_src_type_.
+ in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ } else {
+ in = inp;
+ }
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
diff --git a/tensorflow/core/kernels/colorspace_op.h b/tensorflow/core/kernels/colorspace_op.h
index 90bfce1419..4de14bc339 100644
--- a/tensorflow/core/kernels/colorspace_op.h
+++ b/tensorflow/core/kernels/colorspace_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -91,4 +91,4 @@ struct HSVToRGB {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h
index 720b506537..29f3a427fe 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.h
+++ b/tensorflow/core/kernels/concat_lib_cpu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -162,3 +165,5 @@ void ConcatSYCLImpl(
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 375819a8a2..426c404f43 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -259,8 +259,9 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
- TensorShape({}));
+ // DT_VARIANT tensors must be allocated on CPU since they wrap C++
+ // objects which can not be efficiently represented in GPU memory.
+ Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 5bf709af08..fc0a2f123f 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -63,7 +63,7 @@ Status ConvBackpropExtractAndVerifyDimensionV2(
return errors::InvalidArgument(
label, ": Size of out_backprop doesn't match computed: ", "actual = ",
dim->output_size, ", computed = ", out_size,
- "spatial_dim: ", spatial_dim, " input: ", dim->input_size,
+ " spatial_dim: ", spatial_dim, " input: ", dim->input_size,
" filter: ", dim->filter_size, " output: ", dim->output_size,
" stride: ", dim->stride, " dilation: ", dim->dilation);
}
diff --git a/tensorflow/core/kernels/cross_op.h b/tensorflow/core/kernels/cross_op.h
index ca6beba52b..45bc46a921 100644
--- a/tensorflow/core/kernels/cross_op.h
+++ b/tensorflow/core/kernels/cross_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -51,4 +51,4 @@ struct Cross {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index b2e8ee23a9..2c30d036df 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================
*/
+#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+
// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.
@@ -433,3 +436,5 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index d6a2403816..35662e278f 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,8 +24,7 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
-REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
- int32, int64);
+REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 1014519059..de164c1c09 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -154,8 +154,8 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
};
template <typename T>
-struct unsafe_div_op {
- EIGEN_EMPTY_STRUCT_CTOR(unsafe_div_op)
+struct div_no_nan_op {
+ EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
const T& b) const {
if (b != 0) {
@@ -167,7 +167,7 @@ struct unsafe_div_op {
};
template <typename T>
-struct functor_traits<unsafe_div_op<T>> {
+struct functor_traits<div_no_nan_op<T>> {
enum {
Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
PacketAccess = false,
@@ -742,7 +742,7 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
-struct unsafe_div : base<T, Eigen::internal::unsafe_div_op<T>> {};
+struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index 965e42dcce..cfae273bf4 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
#define EIGEN_USE_GPU
@@ -188,4 +188,4 @@ struct ApproximateEqual<GPUDevice, T> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
index e81b840a50..15e5de0f72 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
#define EIGEN_USE_GPU
@@ -68,4 +68,4 @@ struct SimpleBinaryFunctor<GPUDevice, Functor> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index e492a8215a..cfa96d910d 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -279,7 +279,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
- if (i == 0) {
+ const bool element_acquired_sloppily =
+ dataset()->sloppy_ && i > 1;
+ if (!element_acquired_sloppily) {
+ // If the element was acquired in the regular (non-sloppy)
+ // order, then advance the current block and cycle pointers to
+ // the next element in the regular order.
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % interleave_indices_.size();
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
index 4b30c1f17f..1c80844085 100644
--- a/tensorflow/core/kernels/gemm_functors.h
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -24,6 +24,9 @@ limitations under the License.
#error "EIGEN_USE_THREADS must be enabled by all .cc files including this."
#endif // EIGEN_USE_THREADS
+#ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+#define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+
#include <string.h>
#include <map>
#include <vector>
@@ -116,3 +119,5 @@ class FastGemmFunctor<float, float, float> {
}
};
#endif // USE_CBLAS_GEMM
+
+#endif // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
diff --git a/tensorflow/core/kernels/hexagon/soc_interface.h b/tensorflow/core/kernels/hexagon/soc_interface.h
index 062103ed98..d1a41d47c8 100644
--- a/tensorflow/core/kernels/hexagon/soc_interface.h
+++ b/tensorflow/core/kernels/hexagon/soc_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
-#define TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
// Declaration of APIs provided by hexagon shared library. This header is shared
// with both hexagon library built with qualcomm SDK and tensorflow.
@@ -111,4 +111,4 @@ void soc_interface_SetDebugFlag(uint64_t flag);
}
#endif // __cplusplus
-#endif // TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 42871c6113..b3f74c060b 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -261,14 +261,15 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
out_tensor.flat<dtype>().constant(dtype(0)); \
break;
- TF_CALL_NUMBER_TYPES(DTYPE_CASE)
+ TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
default:
return errors::InvalidArgument(
- "Trying to compute zeros_like for unsupported dtype",
- out_tensor.dtype());
+ "Trying to compute zeros_like for unsupported dtype ",
+ DataTypeString(out_tensor.dtype()));
}
+ y->tensors.emplace_back(out_tensor);
}
return Status::OK();
}
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 3977f16299..35ca2b9ad0 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -102,9 +102,12 @@ class LookupTableOp : public OpKernel {
~LookupTableOp() override {
// If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(
- cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
- cinfo_.container(), cinfo_.name()));
+ if (!cinfo_.resource_manager()
+ ->template Delete<lookup::LookupInterface>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
}
}
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index 97cc950793..b04e36db8e 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct MatrixBandPartFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
diff --git a/tensorflow/core/kernels/matrix_diag_op.h b/tensorflow/core/kernels/matrix_diag_op.h
index 14095845b8..108ba0f56b 100644
--- a/tensorflow/core/kernels/matrix_diag_op.h
+++ b/tensorflow/core/kernels/matrix_diag_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
@@ -91,4 +91,4 @@ struct MatrixDiag {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
index 0e09078365..00a05a87a3 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Cholesky"
@@ -159,3 +162,5 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 50c25e1da7..afbfaa83f3 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -82,11 +82,11 @@ struct MklConvBwdFilterParams {
};
template <typename T>
-class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdFilterPrimitive(
- const MklConvBwdFilterParams& convBwdFilterDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_bwd_filter == nullptr) {
@@ -94,7 +94,7 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
}
}
- ~MklConv2DBwdFilterPrimitive() {}
+ ~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
@@ -297,38 +297,36 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DBwdFilterPrimitive<T>* Get(
+ static MklConvBwdFilterPrimitive<T>* Get(
const MklConvBwdFilterParams& convBwdFilterDims) {
- MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
// look into the pool for reusable primitive
- conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
- convBwdFilterDims));
-
- if (conv2d_bwd_filter == nullptr) {
- conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
- convBwdFilterDims);
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
- convBwdFilterDims, conv2d_bwd_filter);
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
+ convBwdFilterDims));
+
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
}
- return conv2d_bwd_filter;
+ return conv_bwd_filter;
}
-
private:
- MklConv2DBwdFilterPrimitiveFactory() {}
- ~MklConv2DBwdFilterPrimitiveFactory() {}
+ MklConvBwdFilterPrimitiveFactory() {}
+ ~MklConvBwdFilterPrimitiveFactory() {}
- static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
- static MklConv2DBwdFilterPrimitiveFactory instance_;
+ static MklConvBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConvBwdFilterPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
- string prefix = "conv2d_bwd_filter";
+ string prefix = "conv_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -342,14 +340,14 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdFilter(
+ MklPrimitive* GetConvBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
- void SetConv2dBwdFilter(
- const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
@@ -738,14 +736,13 @@ TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#else
template <typename Device, class T, bool biasEnabled>
-class MklConv2DCustomBackpropFilterOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropFilterOp
+ : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropFilterOp() {}
+ ~MklConvCustomBackpropFilterOp() {}
void Compute(OpKernelContext* context) {
try {
@@ -753,6 +750,9 @@ class MklConv2DCustomBackpropFilterOp
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
+ // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -813,7 +813,10 @@ class MklConv2DCustomBackpropFilterOp
&fwd_dst_dims, &padding_left, &padding_right);
if (!context->status().ok()) return;
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
+
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
@@ -832,21 +835,19 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
TensorShape obp_tf_shape = GetTfShape(context, 2);
depth = (this->data_format_ == FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(isConv2D ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
- convBwdFilterDims);
- auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+ conv_bwd_filter =
+ MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims);
+ auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_fitler and diff_bias (w bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
@@ -854,14 +855,26 @@ class MklConv2DCustomBackpropFilterOp
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
- // output_dims_mkl_order is in OIHW format.
- TensorShape diff_filter_tf_shape(
- {bwd_output_dims[MklDnnDims::Dim_H],
- bwd_output_dims[MklDnnDims::Dim_W],
- bwd_output_dims[MklDnnDims::Dim_I],
- bwd_output_dims[MklDnnDims::Dim_O]});
- AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
- diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ if (isConv2D) {
+ // Conv2D: output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape({bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ } else {
+ // Conv3D: output_dims_mkl_order is in OIDHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims3D::Dim3d_D],
+ bwd_output_dims[MklDnnDims3D::Dim3d_H],
+ bwd_output_dims[MklDnnDims3D::Dim3d_W],
+ bwd_output_dims[MklDnnDims3D::Dim3d_I],
+ bwd_output_dims[MklDnnDims3D::Dim3d_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ }
Tensor* diff_bias_tensor = nullptr;
if (biasEnabled) {
@@ -871,7 +884,7 @@ class MklConv2DCustomBackpropFilterOp
// check if src and diff_dst need reorder
T *src_data = nullptr;
- if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -882,7 +895,7 @@ class MklConv2DCustomBackpropFilterOp
T *diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
- conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ conv_bwd_filter->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -897,7 +910,7 @@ class MklConv2DCustomBackpropFilterOp
bool diff_filter_reorder_required = false;
T *diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
- conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
diff_filter_tensor);
@@ -915,10 +928,10 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
T* diff_bias_data = static_cast<T*>(const_cast<T*>(
diff_bias_tensor->flat<T>().data()));
- conv2d_bwd_filter->Execute(src_data, diff_filter_data,
- diff_bias_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
+ diff_dst_data);
} else {
- conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
@@ -947,7 +960,7 @@ class MklConv2DCustomBackpropFilterOp
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
CHECK(!filter_mkl_shape.IsMklTensor())
- << "Conv2DBackpropFilter: filter should not be in MKL Layout";
+ << "ConvBackpropFilter: filter should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -983,9 +996,11 @@ class MklConv2DCustomBackpropFilterOp
return fwd_filter_dims;
}
- // Output layout is Tensorflow's filter layout (HWIO).
+ // Output layout is Tensorflow's filter layout
+ // Conv2D: HWIO; Conv3D: DHWIO
memory::format GetOutputFormat(const memory::format data_format) {
- return memory::format::hwio;
+ return (this->strides_.size() == 4) ? memory::format::hwio
+ : memory::format::dhwio;
}
// Allocate output tensor.
@@ -1027,24 +1042,27 @@ class MklConv2DCustomBackpropFilterOp
}
};
-#define REGISTER_MKL_FILTER_KERNELS(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilter") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \
- REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklDummyOp<CPUDevice, T>);
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, true>); \
+ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDummyOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..b5a98301e2 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_ML_ONLY
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
};
template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdInputPrimitive(
- const MklConvBwdInputParams& convBwdInputDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
Setup(convBwdInputDims);
}
}
- ~MklConv2DBwdInputPrimitive() {}
+ ~MklConvBwdInputPrimitive() {}
// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
private:
- // Primitive reuse context for Conv2D Bwd Input op
+ // Primitive reuse context for Conv Bwd Input op
struct ConvBwdInputContext {
// expected memory format for this primitive instance
memory::format filter_fmt;
@@ -235,38 +235,37 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
- MklConv2DBwdInputPrimitiveFactory() {}
- ~MklConv2DBwdInputPrimitiveFactory() {}
+ MklConvBwdInputPrimitiveFactory() {}
+ ~MklConvBwdInputPrimitiveFactory() {}
public:
- static MklConv2DBwdInputPrimitive<T>* Get(
+ static MklConvBwdInputPrimitive<T>* Get(
const MklConvBwdInputParams& convBwdInputDims) {
- MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
// look into the pool for reusable primitive
- conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
convBwdInputDims));
- if (conv2d_bwd_input == nullptr) {
- conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
- convBwdInputDims);
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
- convBwdInputDims, conv2d_bwd_input);
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
}
- return conv2d_bwd_input;
+ return conv_bwd_input;
}
private:
- static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
- static MklConv2DBwdInputPrimitiveFactory instance_;
+ static MklConvBwdInputPrimitiveFactory& GetInstance() {
+ static MklConvBwdInputPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
- string prefix = "conv2d_bwd_input";
+ string prefix = "conv_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -279,14 +278,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims) {
+ MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
- void SetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
@@ -594,23 +592,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
#else
template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropInputOp() {}
+ ~MklConvCustomBackpropInputOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
+ // This flag indicate Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -626,7 +635,7 @@ class MklConv2DCustomBackpropInputOp
diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
// tensor containing shape of filter. So filter.shape() is not
// a correct way to get filter shape. These operator-specific calls
// allow this class to handle this case.
@@ -655,6 +664,7 @@ class MklConv2DCustomBackpropInputOp
}
return;
}
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +683,18 @@ class MklConv2DCustomBackpropInputOp
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
@@ -689,18 +702,15 @@ class MklConv2DCustomBackpropInputOp
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims,
MklDnnType<T>(), tf_fmt);
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
- conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
- convBwdInputDims);
- auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+ conv_bwd_input =
+ MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims);
+ auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
@@ -723,7 +733,7 @@ class MklConv2DCustomBackpropInputOp
// check if filter and diff_dst need reorder
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
- conv2d_bwd_input->GetFilterMemoryFormat()) {
+ conv_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +743,7 @@ class MklConv2DCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
- if (diff_dst_md.data.format !=
- conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -745,7 +754,7 @@ class MklConv2DCustomBackpropInputOp
}
// execute convolution input bwd
- conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -770,7 +779,7 @@ class MklConv2DCustomBackpropInputOp
// of the Tensor and never an actual tensor. So it will never be in MKL
// layout.
CHECK(!input_mkl_shape.IsMklTensor())
- << "Conv2DBackpropInput: input should not be in MKL Layout";
+ << "ConvBackpropInput: input should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -778,10 +787,10 @@ class MklConv2DCustomBackpropInputOp
const Tensor& input_tensor) {
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
- CHECK_EQ(
- TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
- .ok(),
- true);
+ // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64
+ // output_shape MakeShape is able to handle both DT_INT32 and DT_INT64 for
+ // input_tensor.
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}
@@ -792,7 +801,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
@@ -800,7 +809,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
@@ -839,17 +848,22 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
+#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index bca1aa21a8..c6295c7280 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -85,9 +85,9 @@ struct MklConvFwdParams {
};
template <typename T>
-class MklConv2DFwdPrimitive : public MklPrimitive {
+class MklConvFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -96,7 +96,7 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
}
}
- ~MklConv2DFwdPrimitive() {}
+ ~MklConvFwdPrimitive() {}
// Convolution forward execute with bias
// src_data: input data buffer of src
@@ -269,37 +269,36 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
// try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
- conv2d_fwd);
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
}
- return conv2d_fwd;
+ return conv_fwd;
}
private:
- MklConv2DFwdPrimitiveFactory() {}
- ~MklConv2DFwdPrimitiveFactory() {}
+ MklConvFwdPrimitiveFactory() {}
+ ~MklConvFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1;
- static MklConv2DFwdPrimitiveFactory& GetInstance() {
- static MklConv2DFwdPrimitiveFactory instance_;
+ static MklConvFwdPrimitiveFactory& GetInstance() {
+ static MklConvFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvFwdParams& convFwdDims) {
- string prefix = "conv2d_fwd_";
+ string prefix = "conv_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -313,12 +312,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
+ MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
+ void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -331,11 +330,11 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// For now, MKL-ML is default. So making MKL-DNN not a default choice.
#ifdef INTEL_MKL_ML_ONLY
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -755,21 +754,22 @@ class MklConv2DOp : public OpKernel {
#else
+// Base class for convolution forward operations
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES(context, strides_.size() == 4,
+ OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
@@ -778,20 +778,39 @@ class MklConv2DOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ } else if (strides_.size() == 5) {
+ OP_REQUIRES(context, dilations_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
+ GetTensorDim(dilations_, data_format_, 'C') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations rates in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(dilations_, data_format_, '0') > 0 &&
+ GetTensorDim(dilations_, data_format_, '1') > 0 &&
+ GetTensorDim(dilations_, data_format_, '2') > 0),
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -837,7 +856,8 @@ class MklConv2DOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst,
&dst_tensor, src_tf_shape, dst_mkl_shape);
- // MklConv2D also outputs converted filter as 2nd output of Conv2D.
+ // MklConv2D/3D also outputs converted filter
+ // as 2nd output of Conv2D/3D.
filter_mkl_shape.SetMklTensor(false);
Tensor* output_filter_tensor = nullptr;
AllocateOutputSetMklShape(context, kOutputIndex_Filter,
@@ -846,15 +866,20 @@ class MklConv2DOp : public OpKernel {
return;
}
+ bool isConv2D = (strides_.size() == 4);
+
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto tf_fmt = isConv2D ? TFDataFormatToMklDnnDataFormat(data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(data_format_);
// If input is in MKL layout, then simply grab input layout; otherwise,
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
- // layout (NHWC or NCHW depending on data format).
+ // layout depending on data format:
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
@@ -864,31 +889,30 @@ class MklConv2DOp : public OpKernel {
auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(),
- memory::format::hwio);
-
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
// MKLDNN dilation starts from 0.
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
}
// allocate output tensors output_tensor and filter_out_tensor
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
- conv2d_fwd->GetPrimitiveDesc();
+ conv_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -900,7 +924,7 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder
T *src_data = nullptr;
- if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
+ if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -908,7 +932,7 @@ class MklConv2DOp : public OpKernel {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* filter_data = nullptr;
- if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
+ if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor));
@@ -918,16 +942,15 @@ class MklConv2DOp : public OpKernel {
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
-
// execute convolution
if (biasEnabled) {
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
T* bias_data = static_cast<T*>(const_cast<T*>(
bias_tensor.flat<T>().data()));
- conv2d_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
} else {
- conv2d_fwd->Execute(src_data, filter_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, dst_data);
}
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
@@ -1038,17 +1061,18 @@ class MklConv2DOp : public OpKernel {
#endif
+// Register 2D operations
#define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, false>); \
+ MklConvOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, true>); \
+ MklConvOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
@@ -1057,5 +1081,14 @@ class MklConv2DOp : public OpKernel {
TF_CALL_float(REGISTER_MKL_CPU);
+// Register 3D operations
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvOp<CPUDevice, T, false>);
+TF_CALL_float(REGISTER_MKL_CPU);
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 838c06f49d..01cc606f41 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -79,9 +79,16 @@ class MklDnnConvUtil {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
- int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- *strides = {stride_rows, stride_cols};
+ if (strides_.size() == 4) {
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ } else if (strides_.size() == 5) {
+ int stride_planes = GetTensorDim(strides_, data_format_, '0');
+ int stride_rows = GetTensorDim(strides_, data_format_, '1');
+ int stride_cols = GetTensorDim(strides_, data_format_, '2');
+ *strides = {stride_planes, stride_rows, stride_cols};
+ }
}
// Calculate Convolution dilations
@@ -89,13 +96,20 @@ class MklDnnConvUtil {
// For now we take the dilation from the second and third dimensions only
// (we do not support dilation on the batch or depth dimension).
CHECK_NOTNULL(dilations);
- int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
- int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
- *dilations = {dilations_rows, dilations_cols};
+ if (dilations_.size() == 4) {
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ } else if (dilations_.size() == 5) {
+ int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
+ int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
+ *dilations = {dilations_planes, dilations_rows, dilations_cols};
+ }
}
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
- // requires input in NCHW format. Function does not return anything.
+ // requires input in NCHW/NCDHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
@@ -113,40 +127,62 @@ class MklDnnConvUtil {
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
- // Input rows/height
- int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
- CHECK_BOUNDS(input_rows_raw, "Input rows too large");
- int input_rows = static_cast<int>(input_rows_raw);
-
- // Input columns/width
- int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
- CHECK_BOUNDS(input_cols_raw, "Input cols too large");
- int input_cols = static_cast<int>(input_cols_raw);
-
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
+ if (strides_.size() == 4) { // NCHW format for Conv2D
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCHW format Conv2D.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ } else if (strides_.size() == 5) { // NCDHW format for Conv3D
+ // Input planes/third-dimension
+ int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
+ CHECK_BOUNDS(input_planes_raw, "Input depth too large");
+ int input_planes = static_cast<int>(input_planes_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCDHW format for Conv3D.
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ }
#undef CHECK_BOUNDS
-
- // MKL-DNN always requires input in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
-
- *input_dims = mkldnn_sizes;
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
- //
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
@@ -159,11 +195,13 @@ class MklDnnConvUtil {
memory::dims* filter_dims) {
CHECK_NOTNULL(filter_dims);
- OP_REQUIRES(context_, filter_shape.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
+ OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
+ errors::InvalidArgument((strides_.size() == 4)
+ ? "filter must be 4-dimensional: "
+ : "filter must be 5-dimensional: ",
filter_shape.DebugString()));
- for (int i = 0; i < 3; i++) {
+ for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
OP_REQUIRES(context_,
FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
@@ -172,32 +210,57 @@ class MklDnnConvUtil {
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
-
- // TF filter is always in (rows, cols, in_depth, out_depth) order.
- int filter_rows = static_cast<int>(filter_shape.dim_size(0));
- int filter_cols = static_cast<int>(filter_shape.dim_size(1));
- int in_depth = static_cast<int>(filter_shape.dim_size(2));
- int out_depth = static_cast<int>(filter_shape.dim_size(3));
-
- // MKL-DNN always needs filter in OIHW format.
- // OIHW = (out_depth, in_depth, rows, cols)
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
-
- *filter_dims = mkldnn_sizes;
+ if (strides_.size() == 4) { // Conv2D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ } else { // Conv3D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(3)));
+
+ // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
+ int filter_planes = static_cast<int>(filter_shape.dim_size(0));
+ int filter_rows = static_cast<int>(filter_shape.dim_size(1));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(2));
+ int in_depth = static_cast<int>(filter_shape.dim_size(3));
+ int out_depth = static_cast<int>(filter_shape.dim_size(4));
+
+ // MKL-DNN always needs filter in OIDHW format.
+ // OIDHW = (out_depth, in_depth, planes, rows, cols)
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ }
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format.
+ // Function does not return anything. But errors arising from sanity
+ // checks are returned in context's status.
virtual inline void GetFilterSizeInMklOrder(size_t src_index,
size_t filter_index,
memory::dims* filter_dims) {
@@ -206,8 +269,8 @@ class MklDnnConvUtil {
GetTfShape(context_, filter_index), filter_dims);
}
- // Calculate Bias size for 2D Convolution. Function does not return
- // anything, but sets error in context status.
+ // Calculate Bias size for 2D or 3D Convolution. Function does not
+ // return anything, but may set an error in context status.
virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
memory::dims* bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
@@ -218,73 +281,142 @@ class MklDnnConvUtil {
*bias_dims = {static_cast<int>(bias.dim_size(0))};
}
- // Function to calculate output and padding size for 2D convolution.
+ // Function to calculate output and padding size for 2D/3D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function also calculates
- // left, right, top and bottom pads. Function does not return any status -
- // status is returned via context status.
+ // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order.
+ // But TensorFlow output will be in NHWC||NCHW(Conv2D) or
+ // NDHWC||NCDHW(Conv3D) format depending on data format.
+ // Function also calculates left, right, top and bottom pads.
+ // Function does not return any status which is set with context status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order,
- memory::dims* output_dims_mkl_order, memory::dims* pad_l,
- memory::dims* pad_r) {
+ memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
+ memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- int input_rows = GetTensorDim(input_shape, data_format_, 'H');
- int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ bool isConv2D = (strides_.size() == 4);
+ int input_planes, input_rows, input_cols;
+ if (isConv2D) {
+ input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ } else {
+ input_planes = GetTensorDim(input_shape, data_format_, '0');
+ input_rows = GetTensorDim(input_shape, data_format_, '1');
+ input_cols = GetTensorDim(input_shape, data_format_, '2');
+ }
- // The first dimension for filter is rows/height.
- int filter_rows = filter_shape.dim_size(0);
- // The second dimension for filter is cols/width.
- int filter_cols = filter_shape.dim_size(1);
+ // Filter dimension
+ // Conv2D:
+ // First dimension: rows/height.
+ // Second dimension: cols/width.
+ // Conv3D:
+ // First dimension: planes/depth.
+ // Second dimension: rows/height.
+ // Third dimension: cols/width.
+
+ int filter_planes, filter_rows, filter_cols;
+ if (isConv2D) {
+ filter_rows = filter_shape.dim_size(0);
+ filter_cols = filter_shape.dim_size(1);
+ } else {
+ filter_planes = filter_shape.dim_size(0);
+ filter_rows = filter_shape.dim_size(1);
+ filter_cols = filter_shape.dim_size(2);
+ }
- // Stride is vector of 2 elements: {s_r, s_c}
- int stride_rows = strides[0];
- int stride_cols = strides[1];
- int dilation_rows = dilations[0];
- int dilation_cols = dilations[1];
+ int stride_planes, stride_rows, stride_cols;
+ int dilation_planes, dilation_rows, dilation_cols;
+ if (isConv2D) {
+ // Conv2D stride is a vector of 2 elements: {s_r, s_c}
+ stride_rows = strides[0];
+ stride_cols = strides[1];
+ dilation_rows = dilations[0];
+ dilation_cols = dilations[1];
+ } else {
+ // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
+ stride_planes = strides[0];
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ dilation_planes = dilations[0];
+ dilation_rows = dilations[1];
+ dilation_cols = dilations[2];
+ }
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+
// Output depth is same as last dimension for filter.
- int out_depth = filter_shape.dim_size(3);
+ int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4);
- int64 out_rows = 0, out_cols = 0;
+ int64 out_rows = 0, out_cols = 0, out_planes = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+ int64 pad_D1, pad_D2;
+
+ if (isConv2D) {
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_rows, filter_rows, dilation_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_cols, filter_cols, dilation_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ } else {
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_planes, filter_planes, stride_planes,
+ padding_, &out_planes, &pad_D1, &pad_D2));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ }
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
- dilation_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
- dilation_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
-
- // Tensorflow output is in data_format order. (NHWC or NCHW)
+ // Tensorflow output is in data_format order.
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
+ // MKL-DNN uses asymetric padding.
TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ isConv2D
+ ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
+ out_depth)
+ : ShapeFromFormat(data_format_, out_batch,
+ {{out_planes, out_rows, out_cols}}, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
- // MKL-DNN always needs output in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
- mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
- *output_dims_mkl_order = mkldnn_sizes;
-
- // Now handle padding. MKL-DNN uses asymetric padding.
- *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
- *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ if (isConv2D) {
+ // For Conv2D, MKL-DNN always needs output in NCHW format.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ } else {
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
+ static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
+ static_cast<int>(pad_right)};
+ }
}
// Calculate output and pad size of forward Convolution operator.
@@ -292,10 +424,10 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index,
- const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
- memory::dims* pad_l, memory::dims* pad_r) {
+ size_t src_index, size_t filter_index, const memory::dims& strides,
+ const memory::dims& dilations, memory::dims* output_dims_tf_order,
+ memory::dims* output_dims_mkl_order, memory::dims* pad_l,
+ memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -304,9 +436,17 @@ class MklDnnConvUtil {
auto input_tf_shape = GetTfShape(context_, src_index);
auto filter_tf_shape = GetTfShape(context_, filter_index);
- OP_REQUIRES(context_, input_tf_shape.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ if (strides_.size() == 4) {
+ // Conv2D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_tf_shape.DebugString()));
+ } else {
+ // Conv3D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input_tf_shape.DebugString()));
+ }
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, dilations, output_dims_tf_order,
@@ -314,9 +454,11 @@ class MklDnnConvUtil {
}
// Wrapper function to calculate input, filter, and output sizes of
- // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
- // Function also calculates output shape in Tensorflow order. Additionally, it
- // also calculates strides and paddings for 2D Convolution.
+ // Conv2D/Conv3D in MKL order:
+ // Conv2D: NCHW for input and output; OIHW for filter.
+ // Conv3D: NCDHW for input and output; OIDHW for filter.
+ // Function also calculates output shape in Tensorflow order.
+ // Additionally, it also calculates strides and paddings.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(
@@ -349,16 +491,15 @@ class MklDnnConvUtil {
}
};
-
/////////////////////////////////////////////////////////////////////
-/// Common class that implements Conv2DBackpropFilter and Input
+/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
template <typename Device, class T>
-class MklConv2DBackpropCommonOp : public OpKernel {
+class MklConvBackpropCommonOp : public OpKernel {
public:
- ~MklConv2DBackpropCommonOp() {}
- explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
+ ~MklConvBackpropCommonOp() {}
+ explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
@@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ // Check Conv2D dilations
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 0ab9ff9f65..9ec83b867f 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -47,7 +47,7 @@ using random::PhiloxRandom;
template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
@@ -124,6 +124,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
(normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
+
if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
@@ -143,15 +144,20 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const auto u = dist(&gen_copy);
for (int i = 0; i < size; i++) {
- if (u[i] <= Eigen::numext::exp(g[i]) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = u[i] <= Eigen::numext::exp(g[i]);
+ if (accept || numIterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
- // sample. Emperically, the probability of accepting each sample
- // is at least 50% for typical inputs, so we will always accept
- // by 100 iterations.
- // This introduces a slight inaccuracy when at least one bound
- // is large, minval is negative and maxval is positive.
+ // sample, but emit a warning.
+ // TODO(jjhunt) For small entropies (relative to the bounds),
+ // this sampler is poor and may take many iterations since
+ // the proposal distribution is the uniform distribution
+ // U(lower_bound, upper_bound).
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal uniform rejection sampler "
+ << "exceeded max iterations. Sample may contain "
+ << "outliers.";
+ }
output(sample) = z[i] * stddev + mean;
sample++;
if (sample >= limit_sample) {
@@ -181,13 +187,15 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const T g = Eigen::numext::exp(-x * x / T(2.0));
const T u = rand[i];
i++;
- if ((u <= g && z < normMax) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = (u <= g && z < normMax);
+ if (accept || numIterations + 1 >= kMaxIterations) {
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal exponential distribution "
+ << "rejection sampler exceeds max iterations. "
+ << "Sample may contain outliers.";
+ }
output(sample) = z * stddev + mean;
sample++;
- if (sample >= limit_sample) {
- break;
- }
numIterations = 0;
} else {
numIterations++;
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
index 661d47d925..5b80a962bc 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -190,7 +190,7 @@ __global__ void __launch_bounds__(1024)
// Partial specialization for GPU
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
diff --git a/tensorflow/core/kernels/pooling_ops_3d_gpu.h b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
index 350b1b6732..2c3681455e 100644
--- a/tensorflow/core/kernels/pooling_ops_3d_gpu.h
+++ b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
-#define TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
#define EIGEN_USE_GPU
@@ -45,4 +45,4 @@ struct MaxPool3dGradBackward {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index 0552c034d2..535df9d160 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
@@ -292,6 +295,8 @@ class QrOpGpu : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 9af4cc23b6..88b3c2ac76 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -1058,4 +1061,6 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
} // namespace functor
} // namespace tensorflow
-#endif
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
index 59ec854a79..a1b948891d 100644
--- a/tensorflow/core/kernels/regex_replace_op.cc
+++ b/tensorflow/core/kernels/regex_replace_op.cc
@@ -20,8 +20,43 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
+namespace {
+
+// Execute the specified regex using the given context.
+// Context requirements:
+// - "input" string Tensor at input_index=0
+// - "output" string Tensor at output_index=0
+Status InternalCompute(const RE2& match, const string& rewrite,
+ const bool replace_global, OpKernelContext* ctx) {
+ const Tensor* input_tensor;
+ TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
+ Tensor* output_tensor;
+ std::unique_ptr<Tensor> maybe_forwarded =
+ ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
+ tensorflow::DT_STRING, input_tensor->shape(),
+ ctx->input_memory_type(0), ctx->input_alloc_attr(0));
+ if (maybe_forwarded) {
+ output_tensor = maybe_forwarded.get();
+ TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
+ } else {
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
+ output_tensor->flat<string>() = input_tensor->flat<string>();
+ }
+ auto output_flat = output_tensor->flat<string>();
+ for (size_t i = 0; i < output_flat.size(); ++i) {
+ if (replace_global) {
+ RE2::GlobalReplace(&output_flat(i), match, rewrite);
+ } else {
+ RE2::Replace(&output_flat(i), match, rewrite);
+ }
+ }
+ return Status::OK();
+}
+} // namespace
class RegexReplaceOp : public OpKernel {
public:
@@ -30,10 +65,6 @@ class RegexReplaceOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- const Tensor* input_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
- const auto& input_flat = input_tensor->flat<string>();
-
const Tensor* pattern_tensor;
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
@@ -51,19 +82,7 @@ class RegexReplaceOp : public OpKernel {
errors::InvalidArgument("Rewrite must be scalar, but received ",
rewrite_tensor->shape().DebugString()));
const string rewrite = rewrite_tensor->flat<string>()(0);
-
- Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
- &output_tensor));
- auto output_flat = output_tensor->flat<string>();
- for (size_t i = 0; i < input_flat.size(); ++i) {
- output_flat(i) = input_flat(i);
- if (replace_global_) {
- RE2::GlobalReplace(&output_flat(i), match, rewrite);
- } else {
- RE2::Replace(&output_flat(i), match, rewrite);
- }
- }
+ OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
}
private:
@@ -73,4 +92,31 @@ class RegexReplaceOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
RegexReplaceOp);
+class StaticRegexReplaceOp : public OpKernel {
+ public:
+ explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx,
+ InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
+ }
+
+ private:
+ string rewrite_str_;
+ std::unique_ptr<RE2> re_;
+ bool replace_global_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
+ StaticRegexReplaceOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc
new file mode 100644
index 0000000000..9691d4a89f
--- /dev/null
+++ b/tensorflow/core/kernels/regex_replace_op_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+const char kRegExPattern[] = "\\p{P}";
+const char kRewrite[] = " ";
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
+ const string& input_rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor pattern(DT_STRING, TensorShape({}));
+ pattern.flat<string>().setConstant(input_pattern);
+ Tensor rewrite(DT_STRING, TensorShape({}));
+ rewrite.flat<string>().setConstant(input_rewrite);
+
+ TF_CHECK_OK(NodeBuilder("regex_replace_op", "RegexReplace")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, pattern))
+ .Input(test::graph::Constant(g, rewrite))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_RegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_RegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
+ const string& rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+
+ TF_CHECK_OK(NodeBuilder("static_regex_replace_op", "StaticRegexReplace")
+ .Attr("pattern", input_pattern)
+ .Attr("rewrite", rewrite)
+ .Input(test::graph::Constant(g, input))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+void BM_StaticRegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StaticRegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
index 271dd2c485..b5274f8788 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
@@ -85,3 +88,5 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
index 494a83ed14..d3fc0e1461 100644
--- a/tensorflow/core/kernels/softplus_op.cc
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftplusOp : public UnaryElementWiseOp<T, SoftplusOp<Device, T>> {
public:
explicit SoftplusOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softplus<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftplusGradOp
: public BinaryElementWiseOp<T, SoftplusGradOp<Device, T>> {
public:
explicit SoftplusGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -89,7 +84,7 @@ void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftplusGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftplusGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
index 00ee649b17..d691f15651 100644
--- a/tensorflow/core/kernels/softsign_op.cc
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
public:
explicit SoftsignOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softsign<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftsignGradOp
: public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
public:
explicit SoftsignGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -90,7 +85,7 @@ void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftsignGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
index b5587aa9d7..6ba7931ab5 100644
--- a/tensorflow/core/kernels/sparse_xent_op.h
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
// Functor definition for SparseXentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -224,4 +224,4 @@ struct SparseXentEigenImpl {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 59fdc2262a..7b537fef5b 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -300,7 +300,8 @@ class StridedSliceAssignOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
- Tensor old_lhs;
+ Tensor* old_lhs = nullptr;
+ Tensor tmp;
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(context,
@@ -308,29 +309,30 @@ class StridedSliceAssignOp : public OpKernel {
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
- old_lhs = *v->tensor();
- OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+ old_lhs = v->tensor();
+ OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
- "l-value dtype ", DataTypeString(old_lhs.dtype()),
+ "l-value dtype ", DataTypeString(old_lhs->dtype()),
" does not match r-value dtype ",
DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
- old_lhs = context->mutable_input(0, true);
+ tmp = context->mutable_input(0, true);
+ old_lhs = &tmp;
}
OP_REQUIRES_OK(
- context,
- ValidateStridedSliceOp(
- &context->input(1), &context->input(2), context->input(3),
- old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
- shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
- &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+ context, ValidateStridedSliceOp(
+ &context->input(1), &context->input(2), context->input(3),
+ old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask, &processing_shape,
+ &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+ &begin, &end, &strides));
if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);
TensorShape input_shape = input.shape();
- TensorShape original_shape = old_lhs.shape();
+ TensorShape original_shape = old_lhs->shape();
// TODO(aselle): This check is too strong, we only should need
// input_shape to be broadcastable to final_shape
OP_REQUIRES(
@@ -345,12 +347,12 @@ class StridedSliceAssignOp : public OpKernel {
// scalar shape
// Handle general dimensions
-#define HANDLE_DIM(NDIM) \
- if (processing_dims == NDIM) { \
- HandleStridedSliceAssignCase<Device, T, NDIM>()( \
- context, begin, end, strides, processing_shape, is_simple_slice, \
- &old_lhs); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (processing_dims == NDIM) { \
+ HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \
+ strides, processing_shape, \
+ is_simple_slice, old_lhs); \
+ return; \
}
HANDLE_DIM(0);
HANDLE_DIM(1);
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
new file mode 100644
index 0000000000..a6829b29d9
--- /dev/null
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+namespace {
+
+class StringLengthOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ auto src = input.flat<string>();
+ auto dst = output->flat<int32>();
+
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
+ StringLengthOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 26ab72f12e..3884370a6c 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -26,25 +26,81 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
-
namespace {
+// Split input string `str` based on a character delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Note: The single character delimiter is a common case and is implemented as
+// a series of finds in the input string, making it much more effcient than
+// SplitOnCharSet.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnChar(const string& str, const char delim,
+ Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ auto f = text.find(delim);
+ while (f != StringPiece::npos) {
+ StringPiece token = text.substr(0, f);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ text.remove_prefix(f + 1);
+ f = text.find(delim);
+ }
+ if (p(text)) {
+ result.push_back(text);
+ }
+ return result;
+}
-std::vector<string> Split(const string& str, const string& delimiter,
- const bool skipEmpty) {
- if (!delimiter.empty()) {
- if (skipEmpty) {
- return str_util::Split(str, delimiter, str_util::SkipEmpty());
+// Split input string `str` based on a set of character delimiters.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Based on str_util::Split.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnCharSet(const string& str,
+ const string& delim_set, Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ StringPiece delims(delim_set);
+ size_t token_start = 0;
+ for (size_t i = 0; i < text.size() + 1; i++) {
+ if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
+ StringPiece token(text.data() + token_start, i - token_start);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ token_start = i + 1;
}
- return str_util::Split(str, delimiter);
}
- std::vector<string> char_vector(str.size());
- for (size_t i = 0; i < str.size(); ++i) {
- char_vector[i] = str[i];
+ return result;
+}
+
+// Split input string `str` based on given delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+template <typename Predicate>
+std::vector<StringPiece> Split(const string& str, const string& delimiter,
+ Predicate predicate) {
+ if (str.empty()) {
+ return std::vector<StringPiece>();
+ }
+ if (delimiter.empty()) {
+ std::vector<StringPiece> result;
+ result.resize(str.size());
+ for (size_t i = 0; i < str.size(); ++i) {
+ result[i] = StringPiece(str.data() + i, 1);
+ }
+ return result;
}
- return char_vector;
+ if (delimiter.size() == 1) {
+ return SplitOnChar(str, delimiter[0], predicate);
+ }
+ return SplitOnCharSet(str, delimiter, predicate);
}
-std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+std::vector<StringPiece> SplitV2(const string& str, StringPiece sep,
+ int maxsplit) {
// This SplitV2 method matches the behavior of python's str.split:
// If sep is given, consecutive delimiters are not grouped together
// and are deemed to delimit empty strings (for example, '1,,2'.split(',')
@@ -59,11 +115,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
// splitting an empty string or a string consisting of just whitespace
// with a None separator returns [].
- std::vector<string> result;
+ std::vector<StringPiece> result;
StringPiece text(str);
if (maxsplit == 0) {
- result.emplace_back(std::string(text));
+ result.emplace_back(text);
return result;
}
@@ -73,11 +129,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
str_util::RemoveLeadingWhitespace(&text);
int split = 0;
while (str_util::ConsumeNonWhitespace(&text, &token)) {
- result.emplace_back(std::string(token));
+ result.push_back(token);
str_util::RemoveLeadingWhitespace(&text);
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
}
@@ -87,17 +143,17 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
int split = 0;
while (p != text.end()) {
StringPiece token = text.substr(0, p - text.begin());
- result.emplace_back(std::string(token));
+ result.push_back(token);
text.remove_prefix(token.size());
text.remove_prefix(sep.size());
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(StringPiece(text));
return result;
}
p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
}
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
@@ -134,7 +190,7 @@ class StringSplitOp : public OpKernel {
const auto delimiter_vec = delimiter_tensor->flat<string>();
const string& delimiter = delimiter_vec(0);
// Empty delimiter means split the input character by character.
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -143,12 +199,15 @@ class StringSplitOp : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
+ std::vector<StringPiece> parts =
+ skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
+ : Split(input_vec(i), delimiter, str_util::AllowEmpty());
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
max_num_entries = std::max(max_num_entries, n_entries);
- tokens.insert(tokens.end(), parts.begin(), parts.end());
+ tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
+ std::make_move_iterator(parts.end()));
}
Tensor* sp_indices_t;
@@ -170,7 +229,7 @@ class StringSplitOp : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
@@ -204,7 +263,7 @@ class StringSplitV2Op : public OpKernel {
sep_tensor->shape().DebugString()));
const auto sep_vec = sep_tensor->flat<string>();
StringPiece sep(sep_vec(0));
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -213,7 +272,7 @@ class StringSplitV2Op : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+ std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
@@ -240,7 +299,7 @@ class StringSplitV2Op : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc
new file mode 100644
index 0000000000..58ad61adc8
--- /dev/null
+++ b/tensorflow/core/kernels/string_split_op_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupStringSplitGraph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor delim(DT_STRING, TensorShape({}));
+ delim.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, delim))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplit(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitGraph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplit)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStringSplitV2Graph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor sep(DT_STRING, TensorShape({}));
+ sep.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, sep))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplitV2(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitV2Graph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplitV2)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/svd_op_impl.h b/tensorflow/core/kernels/svd_op_impl.h
index a996b67c62..2a67700c12 100644
--- a/tensorflow/core/kernels/svd_op_impl.h
+++ b/tensorflow/core/kernels/svd_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual svd_*op*.cc files for registering
@@ -101,3 +104,5 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index b368ffc875..632b65e9b6 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -1119,8 +1119,8 @@ class TensorArrayUnpackOrScatterOp : public OpKernel {
{1, num_values, element_shape.num_elements()});
Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1,
- element_shape.num_elements()};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, 1, static_cast<Eigen::DenseIndex>(element_shape.num_elements())};
std::vector<PersistentTensor> write_values;
write_values.reserve(num_values);
@@ -1315,9 +1315,11 @@ class TensorArraySplitOp : public OpKernel {
PersistentTensor persistent_tensor;
int64 previous_length = (i == 0) ? 0 : cumulative_lengths[i - 1];
- Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, previous_length, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, tensor_lengths_t(i),
- elements_per_row};
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{
+ 0, static_cast<Eigen::DenseIndex>(previous_length), 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, static_cast<Eigen::DenseIndex>(tensor_lengths_t(i)),
+ static_cast<Eigen::DenseIndex>(elements_per_row)};
OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
tensor_array->ElemType(), element_shapes[i],
diff --git a/tensorflow/core/kernels/warn_about_ints.cc b/tensorflow/core/kernels/warn_about_ints.cc
deleted file mode 100644
index 75ecdf2ae4..0000000000
--- a/tensorflow/core/kernels/warn_about_ints.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/kernels/warn_about_ints.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-
-namespace tensorflow {
-
-void WarnAboutInts(OpKernelConstruction* context) {
- DataType dtype;
- OP_REQUIRES_OK(context, context->GetAttr("T", &dtype));
- if (DataTypeIsInteger(dtype)) {
- LOG(WARNING) << "Op " << context->def().name() << " of type "
- << context->def().op() << " used with integer dtype "
- << DataTypeString(dtype)
- << ". This op was registered with integer support "
- << "accidentally, and you won't like the result.";
- }
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 57f51889de..8879d9dd4c 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -346,3 +349,5 @@ TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
index 87be17fca9..23d3ad39a8 100644
--- a/tensorflow/core/kernels/xent_op.h
+++ b/tensorflow/core/kernels/xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_XENT_OP_H_
// Functor definition for XentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -125,4 +125,4 @@ struct XentEigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_XENT_OP_H_
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index d7ecc44e50..329f115608 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -31,6 +31,7 @@ limitations under the License.
#include <string.h>
#include <iosfwd>
#include <string>
+#include <type_traits>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -101,11 +102,18 @@ class StringPiece {
// > 0 iff "*this" > "b"
int compare(StringPiece b) const;
- // Converts to `std::basic_string`.
- template <typename A>
- explicit operator std::basic_string<char, std::char_traits<char>, A>() const {
+ // Converts to various kinds of strings, including `std::basic_string`.
+ template <typename S>
+ explicit operator S() const {
+ static_assert(
+ std::is_same<char, typename S::value_type>::value,
+ "Type mismatch: S must be a string with character type char.");
+ static_assert(
+ std::is_same<std::char_traits<char>, typename S::traits_type>::value,
+ "Type mismatch: S must be a string with traits type "
+ "std::char_traits<char>.");
if (!data()) return {};
- return std::basic_string<char, std::char_traits<char>, A>(data(), size());
+ return S(data(), size());
}
private:
diff --git a/tensorflow/core/lib/core/stringpiece_test.cc b/tensorflow/core/lib/core/stringpiece_test.cc
index 952b9eaaaa..e4b489fe17 100644
--- a/tensorflow/core/lib/core/stringpiece_test.cc
+++ b/tensorflow/core/lib/core/stringpiece_test.cc
@@ -56,8 +56,8 @@ TEST(StringPiece, Ctor) {
}
TEST(StringPiece, ConversionToString) {
- EXPECT_EQ("", std::string(StringPiece("")));
- EXPECT_EQ("foo", std::string(StringPiece("foo")));
+ EXPECT_EQ("", string(StringPiece("")));
+ EXPECT_EQ("foo", string(StringPiece("foo")));
}
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index c15409a246..03dab390a7 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1620,6 +1620,24 @@ TEST(ArrayOpsTest, Slice_ShapeFn) {
INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
+TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
+ ShapeInferenceTestOp op("StridedSlice");
+ TF_ASSERT_OK(NodeDefBuilder("test", "StridedSlice")
+ .Input("input", 0, DT_FLOAT)
+ .Input("begin", 1, DT_INT32)
+ .Input("end", 2, DT_INT32)
+ .Input("strides", 3, DT_INT32)
+ .Attr("shrink_axis_mask", 1)
+ .Finalize(&op.node_def));
+ op.input_tensors.resize(4);
+ Tensor strides = test::AsTensor<int32>({1});
+ op.input_tensors[3] = &strides;
+ // Slicing on the 0-th dimension.
+ INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
+ // Slicing on the 0-th dimension. This time some of the result dimension is 0.
+ INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
+}
+
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
ShapeInferenceTestOp op("StridedSliceGrad");
op.input_tensors.resize(5);
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 92ccbd979d..d708b5a5e3 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -20317,6 +20317,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -68834,6 +68859,32 @@ op {
}
}
op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -69134,6 +69185,17 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -73406,41 +73468,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 57499a6f1d..07f876cb90 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,18 +495,18 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForBinaryCwise(g, {
- {{"gx"}, "UnsafeDiv", {"dz", "y"}},
+ {{"gx"}, "DivNoNan", {"dz", "y"}},
{{"nx"}, "Neg", {"x"}, {}, {"dz"}},
{{"y2"}, "Square", {"y"}, {}, {"dz"}},
- {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},
+ {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
{{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
});
// clang-format on
}
-REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index b0d1595c31..5ee79809ac 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,14 +753,14 @@ TEST_F(MathGradTest, Div) {
}
}
-TEST_F(MathGradTest, UnsafeDiv) {
+TEST_F(MathGradTest, DivNoNan) {
auto x = test::AsTensor<float>(
{0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
Tensor dx;
Tensor dy;
{
- SymGrad("UnsafeDiv", x, y, &dx, &dy);
+ SymGrad("DivNoNan", x, y, &dx, &dy);
{
auto g = [](float x, float y) {
if (y == 0.f) {
@@ -792,7 +792,7 @@ TEST_F(MathGradTest, UnsafeDiv) {
}
}
{ // Swap x and y.
- SymGrad("UnsafeDiv", y, x, &dy, &dx);
+ SymGrad("DivNoNan", y, x, &dy, &dx);
{
auto g = [](float x, float y) {
if (y == 0.f) {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 49646f1f3a..717263a9b0 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,8 +392,11 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
-REGISTER_OP("UnsafeDiv")
- .BINARY_MORE()
+REGISTER_OP("DivNoNan")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("FloorDiv")
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index ebeb048157..be4c3ed2b6 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -121,7 +121,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
- "UnsafeDiv"}) {
+ "DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e0f25fb4ef..94476acd4b 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1009,6 +1009,7 @@ REGISTER_OP("SeluGrad")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
@@ -1022,6 +1023,7 @@ REGISTER_OP("SoftplusGrad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
@@ -1736,6 +1738,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklConv3D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter: uint8")
+ .Output("output: T")
+ .Output("filter_output: T")
+ .Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv3DShape)
+ .Doc(R"doc(
+MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropInputV2")
+ .Input("input_sizes: Tshape")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Input("mkl_input_sizes: uint8")
+ .Input("mkl_filter: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropFilterV2")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Input("out_backprop: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter_size: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the filter.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklRelu")
.Input("features: T")
.Input("mkl_features: uint8")
@@ -2161,7 +2244,7 @@ REGISTER_OP("_MklToTf")
.Input("mkl_input: uint8")
.Output("output: T")
.Attr("T: {half, float, double}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.
@@ -2183,7 +2266,7 @@ REGISTER_OP("_MklInputConversion")
.Attr(
"T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
"complex64, complex128}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index eda82f9c18..560e706931 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -9190,6 +9190,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -31820,6 +31845,32 @@ op {
}
}
op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -32120,6 +32171,17 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -34949,41 +35011,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 8c39d69157..7aa1e71809 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -37,6 +37,14 @@ REGISTER_OP("RegexReplace")
return Status::OK();
});
+REGISTER_OP("StaticRegexReplace")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Attr("rewrite: string")
+ .Output("output: string")
+ .Attr("replace_global: bool = true")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("RegexFullMatch")
.Input("input: string")
.Input("pattern: string")
@@ -159,6 +167,11 @@ REGISTER_OP("StringStrip")
.Output("output: string")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringLength")
+ .Input("input: string")
+ .Output("output: int32")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("EncodeBase64")
.Input("input: string")
.Output("output: string")
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6383180e94..5ec7a82ae9 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -13,219 +13,224 @@ load(
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
- tf_deps = []
+ tf_deps = []
- # If the package name is in shorthand form (ie: does not contain a ':'),
- # expand it to the full name.
- for dep in deps:
- tf_dep = dep
+ # If the package name is in shorthand form (ie: does not contain a ':'),
+ # expand it to the full name.
+ for dep in deps:
+ tf_dep = dep
- if not ":" in dep:
- dep_pieces = dep.split("/")
- tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
+ if not ":" in dep:
+ dep_pieces = dep.split("/")
+ tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
- tf_deps += [tf_dep + suffix]
+ tf_deps += [tf_dep + suffix]
- return tf_deps
+ return tf_deps
# Modified from @cython//:Tools/rules.bzl
def pyx_library(
- name,
- deps=[],
- py_deps=[],
- srcs=[],
- **kwargs):
- """Compiles a group of .pyx / .pxd / .py files.
-
- First runs Cython to create .cpp files for each input .pyx or .py + .pxd
- pair. Then builds a shared object for each, passing "deps" to each cc_binary
- rule (includes Python headers by default). Finally, creates a py_library rule
- with the shared objects and any pure Python "srcs", with py_deps as its
- dependencies; the shared objects can be imported like normal Python files.
-
- Args:
- name: Name for the rule.
- deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
- py_deps: Pure Python dependencies of the final library.
- srcs: .py, .pyx, or .pxd files to either compile or pass through.
- **kwargs: Extra keyword arguments passed to the py_library.
- """
- # First filter out files that should be run compiled vs. passed through.
- py_srcs = []
- pyx_srcs = []
- pxd_srcs = []
- for src in srcs:
- if src.endswith(".pyx") or (src.endswith(".py")
- and src[:-3] + ".pxd" in srcs):
- pyx_srcs.append(src)
- elif src.endswith(".py"):
- py_srcs.append(src)
- else:
- pxd_srcs.append(src)
- if src.endswith("__init__.py"):
- pxd_srcs.append(src)
-
- # Invoke cython to produce the shared object libraries.
- for filename in pyx_srcs:
- native.genrule(
- name = filename + "_cython_translation",
- srcs = [filename],
- outs = [filename.split(".")[0] + ".cpp"],
- # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
- # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
- cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
- tools = ["@cython//:cython_binary"] + pxd_srcs,
+ name,
+ deps = [],
+ py_deps = [],
+ srcs = [],
+ **kwargs):
+ """Compiles a group of .pyx / .pxd / .py files.
+
+ First runs Cython to create .cpp files for each input .pyx or .py + .pxd
+ pair. Then builds a shared object for each, passing "deps" to each cc_binary
+ rule (includes Python headers by default). Finally, creates a py_library rule
+ with the shared objects and any pure Python "srcs", with py_deps as its
+ dependencies; the shared objects can be imported like normal Python files.
+
+ Args:
+ name: Name for the rule.
+ deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
+ py_deps: Pure Python dependencies of the final library.
+ srcs: .py, .pyx, or .pxd files to either compile or pass through.
+ **kwargs: Extra keyword arguments passed to the py_library.
+ """
+
+ # First filter out files that should be run compiled vs. passed through.
+ py_srcs = []
+ pyx_srcs = []
+ pxd_srcs = []
+ for src in srcs:
+ if src.endswith(".pyx") or (src.endswith(".py") and
+ src[:-3] + ".pxd" in srcs):
+ pyx_srcs.append(src)
+ elif src.endswith(".py"):
+ py_srcs.append(src)
+ else:
+ pxd_srcs.append(src)
+ if src.endswith("__init__.py"):
+ pxd_srcs.append(src)
+
+ # Invoke cython to produce the shared object libraries.
+ for filename in pyx_srcs:
+ native.genrule(
+ name = filename + "_cython_translation",
+ srcs = [filename],
+ outs = [filename.split(".")[0] + ".cpp"],
+ # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
+ # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
+ cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
+ tools = ["@cython//:cython_binary"] + pxd_srcs,
+ )
+
+ shared_objects = []
+ for src in pyx_srcs:
+ stem = src.split(".")[0]
+ shared_object_name = stem + ".so"
+ native.cc_binary(
+ name = shared_object_name,
+ srcs = [stem + ".cpp"],
+ deps = deps + ["//third_party/python_runtime:headers"],
+ linkshared = 1,
+ )
+ shared_objects.append(shared_object_name)
+
+ # Now create a py_library with these shared objects as data.
+ native.py_library(
+ name = name,
+ srcs = py_srcs,
+ deps = py_deps,
+ srcs_version = "PY2AND3",
+ data = shared_objects,
+ **kwargs
)
- shared_objects = []
- for src in pyx_srcs:
- stem = src.split(".")[0]
- shared_object_name = stem + ".so"
- native.cc_binary(
- name=shared_object_name,
- srcs=[stem + ".cpp"],
- deps=deps + ["//third_party/python_runtime:headers"],
- linkshared = 1,
- )
- shared_objects.append(shared_object_name)
-
- # Now create a py_library with these shared objects as data.
- native.py_library(
- name=name,
- srcs=py_srcs,
- deps=py_deps,
- srcs_version = "PY2AND3",
- data=shared_objects,
- **kwargs
- )
-
-def _proto_cc_hdrs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
- return ret
-
-def _proto_cc_srcs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
- return ret
-
-def _proto_py_outs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
- return ret
+def _proto_cc_hdrs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
+ return ret
+
+def _proto_cc_srcs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
+ return ret
+
+def _proto_py_outs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
+ return ret
# Re-defined protocol buffer rule to allow building "header only" protocol
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
# containing select() statements.
def cc_proto_library(
- name,
- srcs=[],
- deps=[],
- cc_libs=[],
- include=None,
- protoc="@protobuf_archive//:protoc",
- internal_bootstrap_hack=False,
- use_grpc_plugin=False,
- use_grpc_namespace=False,
- default_header=False,
- **kargs):
- """Bazel rule to create a C++ protobuf library from proto source files.
-
- Args:
- name: the name of the cc_proto_library.
- srcs: the .proto files of the cc_proto_library.
- deps: a list of dependency labels; must be cc_proto_library.
- cc_libs: a list of other cc_library targets depended by the generated
- cc_library.
- include: a string indicating the include path of the .proto files.
- protoc: the label of the protocol compiler to generate the sources.
- internal_bootstrap_hack: a flag indicate the cc_proto_library is used only
- for bootstraping. When it is set to True, no files will be generated.
- The rule will simply be a provider for .proto files, so that other
- cc_proto_library can depend on it.
- use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
- when processing the proto files.
- default_header: Controls the naming of generated rules. If True, the `name`
- rule will be header-only, and an _impl rule will contain the
- implementation. Otherwise the header-only rule (name + "_headers_only")
- must be referred to explicitly.
- **kargs: other keyword arguments that are passed to cc_library.
- """
-
- includes = []
- if include != None:
- includes = [include]
-
- if internal_bootstrap_hack:
- # For pre-checked-in generated files, we add the internal_bootstrap_hack
- # which will skip the codegen action.
+ name,
+ srcs = [],
+ deps = [],
+ cc_libs = [],
+ include = None,
+ protoc = "@protobuf_archive//:protoc",
+ internal_bootstrap_hack = False,
+ use_grpc_plugin = False,
+ use_grpc_namespace = False,
+ default_header = False,
+ **kargs):
+ """Bazel rule to create a C++ protobuf library from proto source files.
+
+ Args:
+ name: the name of the cc_proto_library.
+ srcs: the .proto files of the cc_proto_library.
+ deps: a list of dependency labels; must be cc_proto_library.
+ cc_libs: a list of other cc_library targets depended by the generated
+ cc_library.
+ include: a string indicating the include path of the .proto files.
+ protoc: the label of the protocol compiler to generate the sources.
+ internal_bootstrap_hack: a flag indicate the cc_proto_library is used only
+ for bootstraping. When it is set to True, no files will be generated.
+ The rule will simply be a provider for .proto files, so that other
+ cc_proto_library can depend on it.
+ use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
+ when processing the proto files.
+ default_header: Controls the naming of generated rules. If True, the `name`
+ rule will be header-only, and an _impl rule will contain the
+ implementation. Otherwise the header-only rule (name + "_headers_only")
+ must be referred to explicitly.
+ **kargs: other keyword arguments that are passed to cc_library.
+ """
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ if internal_bootstrap_hack:
+ # For pre-checked-in generated files, we add the internal_bootstrap_hack
+ # which will skip the codegen action.
+ proto_gen(
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ visibility = ["//visibility:public"],
+ )
+
+ # An empty cc_library to make rule dependency consistent.
+ native.cc_library(
+ name = name,
+ **kargs
+ )
+ return
+
+ grpc_cpp_plugin = None
+ plugin_options = []
+ if use_grpc_plugin:
+ grpc_cpp_plugin = "//external:grpc_cpp_plugin"
+ if use_grpc_namespace:
+ plugin_options = ["services_namespace=grpc"]
+
+ gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
+ gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
+ outs = gen_srcs + gen_hdrs
+
proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ plugin = grpc_cpp_plugin,
+ plugin_language = "grpc",
+ plugin_options = plugin_options,
+ gen_cc = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
)
- # An empty cc_library to make rule dependency consistent.
- native.cc_library(
- name=name,
- **kargs)
- return
-
- grpc_cpp_plugin = None
- plugin_options = []
- if use_grpc_plugin:
- grpc_cpp_plugin = "//external:grpc_cpp_plugin"
- if use_grpc_namespace:
- plugin_options = ["services_namespace=grpc"]
-
- gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
- gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
- outs = gen_srcs + gen_hdrs
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- plugin=grpc_cpp_plugin,
- plugin_language="grpc",
- plugin_options=plugin_options,
- gen_cc=1,
- outs=outs,
- visibility=["//visibility:public"],
- )
-
- if use_grpc_plugin:
- cc_libs += select({
- "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
- "//conditions:default": ["//external:grpc_lib"],
- })
- if default_header:
- header_only_name = name
- impl_name = name + "_impl"
- else:
- header_only_name = name + "_headers_only"
- impl_name = name
-
- native.cc_library(
- name=impl_name,
- srcs=gen_srcs,
- hdrs=gen_hdrs,
- deps=cc_libs + deps,
- includes=includes,
- **kargs)
- native.cc_library(
- name=header_only_name,
- deps=["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
- hdrs=gen_hdrs,
- **kargs)
+ if use_grpc_plugin:
+ cc_libs += select({
+ "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
+ "//conditions:default": ["//external:grpc_lib"],
+ })
+
+ if default_header:
+ header_only_name = name
+ impl_name = name + "_impl"
+ else:
+ header_only_name = name + "_headers_only"
+ impl_name = name
+
+ native.cc_library(
+ name = impl_name,
+ srcs = gen_srcs,
+ hdrs = gen_hdrs,
+ deps = cc_libs + deps,
+ includes = includes,
+ **kargs
+ )
+ native.cc_library(
+ name = header_only_name,
+ deps = ["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
+ hdrs = gen_hdrs,
+ **kargs
+ )
# Re-defined protocol buffer rule to bring in the change introduced in commit
# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
@@ -234,474 +239,512 @@ def cc_proto_library(
# to include the above commit.
def py_proto_library(
name,
- srcs=[],
- deps=[],
- py_libs=[],
- py_extra_srcs=[],
- include=None,
- default_runtime="@protobuf_archive//:protobuf_python",
- protoc="@protobuf_archive//:protoc",
- use_grpc_plugin=False,
+ srcs = [],
+ deps = [],
+ py_libs = [],
+ py_extra_srcs = [],
+ include = None,
+ default_runtime = "@protobuf_archive//:protobuf_python",
+ protoc = "@protobuf_archive//:protoc",
+ use_grpc_plugin = False,
**kargs):
- """Bazel rule to create a Python protobuf library from proto source files
-
- NOTE: the rule is only an internal workaround to generate protos. The
- interface may change and the rule may be removed when bazel has introduced
- the native rule.
-
- Args:
- name: the name of the py_proto_library.
- srcs: the .proto files of the py_proto_library.
- deps: a list of dependency labels; must be py_proto_library.
- py_libs: a list of other py_library targets depended by the generated
- py_library.
- py_extra_srcs: extra source files that will be added to the output
- py_library. This attribute is used for internal bootstrapping.
- include: a string indicating the include path of the .proto files.
- default_runtime: the implicitly default runtime which will be depended on by
- the generated py_library target.
- protoc: the label of the protocol compiler to generate the sources.
- use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
- when processing the proto files.
- **kargs: other keyword arguments that are passed to cc_library.
- """
- outs = _proto_py_outs(srcs, use_grpc_plugin)
-
- includes = []
- if include != None:
- includes = [include]
-
- grpc_python_plugin = None
- if use_grpc_plugin:
- grpc_python_plugin = "//external:grpc_python_plugin"
- # Note: Generated grpc code depends on Python grpc module. This dependency
- # is not explicitly listed in py_libs. Instead, host system is assumed to
- # have grpc installed.
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- gen_py=1,
- outs=outs,
- visibility=["//visibility:public"],
- plugin=grpc_python_plugin,
- plugin_language="grpc"
- )
-
- if default_runtime and not default_runtime in py_libs + deps:
- py_libs = py_libs + [default_runtime]
-
- native.py_library(
- name=name,
- srcs=outs+py_extra_srcs,
- deps=py_libs+deps,
- imports=includes,
- **kargs)
-
-def tf_proto_library_cc(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_stubby_versions = None,
- cc_grpc_version = None,
- j2objc_api_version = 1,
- cc_api_version = 2,
- dart_api_version = 2,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- default_header = False):
- js_codegen = js_codegen # unused argument
- js_api_version = js_api_version # unused argument
- native.filegroup(
- name = name + "_proto_srcs",
- srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
- testonly = testonly,
- visibility = visibility,
- )
-
- use_grpc_plugin = None
- if cc_grpc_version:
- use_grpc_plugin = True
-
- cc_deps = tf_deps(protodeps, "_cc")
- cc_name = name + "_cc"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
+ """Bazel rule to create a Python protobuf library from proto source files
+
+ NOTE: the rule is only an internal workaround to generate protos. The
+ interface may change and the rule may be removed when bazel has introduced
+ the native rule.
+
+ Args:
+ name: the name of the py_proto_library.
+ srcs: the .proto files of the py_proto_library.
+ deps: a list of dependency labels; must be py_proto_library.
+ py_libs: a list of other py_library targets depended by the generated
+ py_library.
+ py_extra_srcs: extra source files that will be added to the output
+ py_library. This attribute is used for internal bootstrapping.
+ include: a string indicating the include path of the .proto files.
+ default_runtime: the implicitly default runtime which will be depended on by
+ the generated py_library target.
+ protoc: the label of the protocol compiler to generate the sources.
+ use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
+ when processing the proto files.
+ **kargs: other keyword arguments that are passed to cc_library.
+ """
+ outs = _proto_py_outs(srcs, use_grpc_plugin)
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ grpc_python_plugin = None
+ if use_grpc_plugin:
+ grpc_python_plugin = "//external:grpc_python_plugin"
+ # Note: Generated grpc code depends on Python grpc module. This dependency
+ # is not explicitly listed in py_libs. Instead, host system is assumed to
+ # have grpc installed.
+
proto_gen(
- name = cc_name + "_genproto",
- deps = [s + "_genproto" for s in cc_deps],
- protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ gen_py = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
+ plugin = grpc_python_plugin,
+ plugin_language = "grpc",
)
- native.cc_library(
- name = cc_name,
- deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
- if_static([name + "_cc_impl"]),
+
+ if default_runtime and not default_runtime in py_libs + deps:
+ py_libs = py_libs + [default_runtime]
+
+ native.py_library(
+ name = name,
+ srcs = outs + py_extra_srcs,
+ deps = py_libs + deps,
+ imports = includes,
+ **kargs
+ )
+
+def tf_proto_library_cc(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_stubby_versions = None,
+ cc_grpc_version = None,
+ j2objc_api_version = 1,
+ cc_api_version = 2,
+ dart_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ default_header = False):
+ js_codegen = js_codegen # unused argument
+ js_api_version = js_api_version # unused argument
+ native.filegroup(
+ name = name + "_proto_srcs",
+ srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
testonly = testonly,
visibility = visibility,
)
- native.cc_library(
- name = cc_name + "_impl",
- deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
- )
- return
-
- cc_proto_library(
- name = cc_name,
- srcs = srcs,
- deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
- cc_libs = cc_libs + if_static(
- ["@protobuf_archive//:protobuf"],
- ["@protobuf_archive//:protobuf_headers"]
- ),
- copts = if_not_windows([
- "-Wno-unknown-warning-option",
- "-Wno-unused-but-set-variable",
- "-Wno-sign-compare",
- ]),
- protoc = "@protobuf_archive//:protoc",
- use_grpc_plugin = use_grpc_plugin,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
-def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
- testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
- py_deps = tf_deps(protodeps, "_py")
- py_name = name + "_py"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
- proto_gen(
- name = py_name + "_genproto",
- deps = [s + "_genproto" for s in py_deps],
+ use_grpc_plugin = None
+ if cc_grpc_version:
+ use_grpc_plugin = True
+
+ cc_deps = tf_deps(protodeps, "_cc")
+ cc_name = name + "_cc"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = cc_name + "_genproto",
+ deps = [s + "_genproto" for s in cc_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.cc_library(
+ name = cc_name,
+ deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
+ if_static([name + "_cc_impl"]),
+ testonly = testonly,
+ visibility = visibility,
+ )
+ native.cc_library(
+ name = cc_name + "_impl",
+ deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
+ )
+
+ return
+
+ cc_proto_library(
+ name = cc_name,
+ srcs = srcs,
+ deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
+ cc_libs = cc_libs + if_static(
+ ["@protobuf_archive//:protobuf"],
+ ["@protobuf_archive//:protobuf_headers"],
+ ),
+ copts = if_not_windows([
+ "-Wno-unknown-warning-option",
+ "-Wno-unused-but-set-variable",
+ "-Wno-sign-compare",
+ ]),
protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ use_grpc_plugin = use_grpc_plugin,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
)
- native.py_library(
+
+def tf_proto_library_py(
+ name,
+ srcs = [],
+ protodeps = [],
+ deps = [],
+ visibility = [],
+ testonly = 0,
+ srcs_version = "PY2AND3",
+ use_grpc_plugin = False):
+ py_deps = tf_deps(protodeps, "_py")
+ py_name = name + "_py"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = py_name + "_genproto",
+ deps = [s + "_genproto" for s in py_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.py_library(
+ name = py_name,
+ deps = py_deps + ["@protobuf_archive//:protobuf_python"],
+ testonly = testonly,
+ visibility = visibility,
+ )
+ return
+
+ py_proto_library(
name = py_name,
- deps = py_deps + ["@protobuf_archive//:protobuf_python"],
- testonly = testonly,
+ srcs = srcs,
+ srcs_version = srcs_version,
+ deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
+ protoc = "@protobuf_archive//:protoc",
+ default_runtime = "@protobuf_archive//:protobuf_python",
visibility = visibility,
+ testonly = testonly,
+ use_grpc_plugin = use_grpc_plugin,
)
- return
-
- py_proto_library(
- name = py_name,
- srcs = srcs,
- srcs_version = srcs_version,
- deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
- protoc = "@protobuf_archive//:protoc",
- default_runtime = "@protobuf_archive//:protobuf_python",
- visibility = visibility,
- testonly = testonly,
- use_grpc_plugin = use_grpc_plugin,
- )
def tf_jspb_proto_library(**kwargs):
- pass
+ pass
def tf_nano_proto_library(**kwargs):
- pass
-
-def tf_proto_library(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_api_version = 2, cc_grpc_version = None,
- dart_api_version = 2, j2objc_api_version = 1,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- provide_cc_alias = False,
- default_header = False):
- """Make a proto library, possibly depending on other proto libraries."""
- _ignore = (js_api_version, js_codegen, provide_cc_alias)
-
- tf_proto_library_cc(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- cc_grpc_version = cc_grpc_version,
- cc_libs = cc_libs,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
- tf_proto_library_py(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- srcs_version = "PY2AND3",
- testonly = testonly,
- visibility = visibility,
- use_grpc_plugin = has_services,
- )
+ pass
+
+def tf_proto_library(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_api_version = 2,
+ cc_grpc_version = None,
+ dart_api_version = 2,
+ j2objc_api_version = 1,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ provide_cc_alias = False,
+ default_header = False):
+ """Make a proto library, possibly depending on other proto libraries."""
+ _ignore = (js_api_version, js_codegen, provide_cc_alias)
+
+ tf_proto_library_cc(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ cc_grpc_version = cc_grpc_version,
+ cc_libs = cc_libs,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
+ )
+
+ tf_proto_library_py(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ srcs_version = "PY2AND3",
+ testonly = testonly,
+ visibility = visibility,
+ use_grpc_plugin = has_services,
+ )
# A list of all files under platform matching the pattern in 'files'. In
# contrast with 'tf_platform_srcs' below, which seletive collects files that
# must be compiled in the 'default' platform, this is a list of all headers
# mentioned in the platform/* files.
def tf_platform_hdrs(files):
- return native.glob(["platform/*/" + f for f in files])
+ return native.glob(["platform/*/" + f for f in files])
def tf_platform_srcs(files):
- base_set = ["platform/default/" + f for f in files]
- windows_set = base_set + ["platform/windows/" + f for f in files]
- posix_set = base_set + ["platform/posix/" + f for f in files]
-
- # Handle cases where we must also bring the posix file in. Usually, the list
- # of files to build on windows builds is just all the stuff in the
- # windows_set. However, in some cases the implementations in 'posix/' are
- # just what is necessary and historically we choose to simply use the posix
- # file instead of making a copy in 'windows'.
- for f in files:
- if f == "error.cc":
- windows_set.append("platform/posix/" + f)
-
- return select({
- "//tensorflow:windows" : native.glob(windows_set),
- "//conditions:default" : native.glob(posix_set),
- })
+ base_set = ["platform/default/" + f for f in files]
+ windows_set = base_set + ["platform/windows/" + f for f in files]
+ posix_set = base_set + ["platform/posix/" + f for f in files]
+
+ # Handle cases where we must also bring the posix file in. Usually, the list
+ # of files to build on windows builds is just all the stuff in the
+ # windows_set. However, in some cases the implementations in 'posix/' are
+ # just what is necessary and historically we choose to simply use the posix
+ # file instead of making a copy in 'windows'.
+ for f in files:
+ if f == "error.cc":
+ windows_set.append("platform/posix/" + f)
+
+ return select({
+ "//tensorflow:windows": native.glob(windows_set),
+ "//conditions:default": native.glob(posix_set),
+ })
def tf_additional_lib_hdrs(exclude = []):
- windows_hdrs = native.glob([
- "platform/default/*.h",
- "platform/windows/*.h",
- "platform/posix/error.h",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_hdrs,
- "//conditions:default" : native.glob([
+ windows_hdrs = native.glob([
"platform/default/*.h",
- "platform/posix/*.h",
- ], exclude = exclude),
- })
+ "platform/windows/*.h",
+ "platform/posix/error.h",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_hdrs,
+ "//conditions:default": native.glob([
+ "platform/default/*.h",
+ "platform/posix/*.h",
+ ], exclude = exclude),
+ })
def tf_additional_lib_srcs(exclude = []):
- windows_srcs = native.glob([
- "platform/default/*.cc",
- "platform/windows/*.cc",
- "platform/posix/error.cc",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_srcs,
- "//conditions:default" : native.glob([
+ windows_srcs = native.glob([
"platform/default/*.cc",
- "platform/posix/*.cc",
- ], exclude = exclude),
- })
+ "platform/windows/*.cc",
+ "platform/posix/error.cc",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_srcs,
+ "//conditions:default": native.glob([
+ "platform/default/*.cc",
+ "platform/posix/*.cc",
+ ], exclude = exclude),
+ })
def tf_additional_minimal_lib_srcs():
- return [
- "platform/default/integral_types.h",
- "platform/default/mutex.h",
- ]
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/mutex.h",
+ ]
def tf_additional_proto_hdrs():
- return [
- "platform/default/integral_types.h",
- "platform/default/logging.h",
- "platform/default/protobuf.h"
- ] + if_windows([
- "platform/windows/integral_types.h",
- ])
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/logging.h",
+ "platform/default/protobuf.h",
+ ] + if_windows([
+ "platform/windows/integral_types.h",
+ ])
+
+def tf_additional_proto_compiler_hdrs():
+ return [
+ "platform/default/protobuf_compiler.h",
+ ]
def tf_additional_proto_srcs():
- return [
- "platform/default/protobuf.cc",
- ]
+ return [
+ "platform/default/protobuf.cc",
+ ]
def tf_additional_human_readable_json_deps():
- return []
+ return []
def tf_additional_all_protos():
- return ["//tensorflow/core:protos_all"]
+ return ["//tensorflow/core:protos_all"]
def tf_protos_all_impl():
- return ["//tensorflow/core:protos_all_cc_impl"]
+ return ["//tensorflow/core:protos_all_cc_impl"]
def tf_protos_all():
- return if_static(
- extra_deps=tf_protos_all_impl(),
- otherwise=["//tensorflow/core:protos_all_cc"])
+ return if_static(
+ extra_deps = tf_protos_all_impl(),
+ otherwise = ["//tensorflow/core:protos_all_cc"],
+ )
def tf_protos_grappler_impl():
- return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
+ return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
def tf_protos_grappler():
- return if_static(
- extra_deps=tf_protos_grappler_impl(),
- otherwise=["//tensorflow/core/grappler/costs:op_performance_data_cc"])
+ return if_static(
+ extra_deps = tf_protos_grappler_impl(),
+ otherwise = ["//tensorflow/core/grappler/costs:op_performance_data_cc"],
+ )
def tf_additional_cupti_wrapper_deps():
- return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
+ return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
def tf_additional_device_tracer_srcs():
- return ["platform/default/device_tracer.cc"]
+ return ["platform/default/device_tracer.cc"]
def tf_additional_device_tracer_cuda_deps():
- return []
+ return []
def tf_additional_device_tracer_deps():
- return []
+ return []
def tf_additional_libdevice_data():
- return []
+ return []
def tf_additional_libdevice_deps():
- return ["@local_config_cuda//cuda:cuda_headers"]
+ return ["@local_config_cuda//cuda:cuda_headers"]
def tf_additional_libdevice_srcs():
- return ["platform/default/cuda_libdevice_path.cc"]
+ return ["platform/default/cuda_libdevice_path.cc"]
def tf_additional_test_deps():
- return []
+ return []
def tf_additional_test_srcs():
- return [
- "platform/default/test_benchmark.cc",
- ] + select({
- "//tensorflow:windows" : [
- "platform/windows/test.cc"
+ return [
+ "platform/default/test_benchmark.cc",
+ ] + select({
+ "//tensorflow:windows": [
+ "platform/windows/test.cc",
],
- "//conditions:default" : [
- "platform/posix/test.cc",
+ "//conditions:default": [
+ "platform/posix/test.cc",
],
})
def tf_kernel_tests_linkstatic():
- return 0
+ return 0
def tf_additional_lib_defines():
- """Additional defines needed to build TF libraries."""
- return select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
- "//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
- "//conditions:default": [],
- }) + if_not_mobile(["TENSORFLOW_USE_ABSL"])
+ """Additional defines needed to build TF libraries."""
+ return select({
+ "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["TENSORFLOW_USE_JEMALLOC"],
+ "//conditions:default": [],
+ }) + if_not_mobile(["TENSORFLOW_USE_ABSL"])
def tf_additional_lib_deps():
- """Additional dependencies needed to build TF libraries."""
- return if_not_mobile(["@com_google_absl//absl/base:base"]) + if_static(
- ["@nsync//:nsync_cpp"],
- ["@nsync//:nsync_headers"]
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- })
+ """Additional dependencies needed to build TF libraries."""
+ return if_not_mobile(["@com_google_absl//absl/base:base"]) + if_static(
+ ["@nsync//:nsync_cpp"],
+ ["@nsync//:nsync_headers"],
+ ) + select({
+ "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
+ "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
+ "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
+ "//conditions:default": [],
+ })
def tf_additional_core_deps():
- return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/core/platform/s3:s3_file_system",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/core/platform/cloud:gcs_file_system",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_hdfs_support_windows_override": [],
+ "//tensorflow:with_hdfs_support_android_override": [],
+ "//tensorflow:with_hdfs_support_ios_override": [],
+ "//tensorflow:with_hdfs_support": [
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support_android_override": [],
+ "//tensorflow:with_aws_support_ios_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/core/platform/s3:s3_file_system",
+ ],
+ "//conditions:default": [],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
- "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
+ "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
+ ],
+ "//conditions:default": [],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
- "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
+ "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
+ ],
+ "//conditions:default": [],
+ })
def tf_lib_proto_parsing_deps():
- return [
- ":protos_all_cc",
- "//third_party/eigen3",
- "//tensorflow/core/platform/default/build_config:proto_parsing",
- ]
+ return [
+ ":protos_all_cc",
+ "//third_party/eigen3",
+ "//tensorflow/core/platform/default/build_config:proto_parsing",
+ ]
+
+def tf_lib_proto_compiler_deps():
+ return [
+ "@protobuf_archive//:protoc_lib",
+ ]
def tf_additional_verbs_lib_defines():
- return select({
- "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_lib_defines():
- return select({
- "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_lib_defines():
- return select({
- "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
+ "//conditions:default": [],
+ })
-def tf_py_clif_cc(name, visibility=None, **kwargs):
- pass
+def tf_py_clif_cc(name, visibility = None, **kwargs):
+ pass
-def tf_pyclif_proto_library(name, proto_lib, proto_srcfile="", visibility=None,
- **kwargs):
- pass
+def tf_pyclif_proto_library(
+ name,
+ proto_lib,
+ proto_srcfile = "",
+ visibility = None,
+ **kwargs):
+ pass
def tf_additional_binary_deps():
- return ["@nsync//:nsync_cpp"] + if_cuda(
- [
- "//tensorflow/stream_executor:cuda_platform",
- "//tensorflow/core/platform/default/build_config:cuda",
- ],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- }) + [
- # TODO(allenl): Split these out into their own shared objects (they are
- # here because they are shared between contrib/ op shared objects and
- # core).
- "//tensorflow/core/kernels:lookup_util",
- "//tensorflow/core/util/tensor_bundle",
- ] + if_mkl_ml(
- [
- "//third_party/intel_mkl_ml",
- ],
- )
+ return ["@nsync//:nsync_cpp"] + if_cuda(
+ [
+ "//tensorflow/stream_executor:cuda_platform",
+ "//tensorflow/core/platform/default/build_config:cuda",
+ ],
+ ) + select({
+ "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
+ "//conditions:default": [],
+ }) + [
+ # TODO(allenl): Split these out into their own shared objects (they are
+ # here because they are shared between contrib/ op shared objects and
+ # core).
+ "//tensorflow/core/kernels:lookup_util",
+ "//tensorflow/core/util/tensor_bundle",
+ ] + if_mkl_ml(
+ [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ )
diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h
index c732c76ff7..bd9d41c62b 100644
--- a/tensorflow/core/platform/default/protobuf.h
+++ b/tensorflow/core/platform/default/protobuf.h
@@ -20,8 +20,8 @@ limitations under the License.
// IWYU pragma: friend third_party/tensorflow/core/platform/protobuf.h
#include "google/protobuf/arena.h"
-#include "google/protobuf/compiler/importer.h"
#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/core/platform/default/protobuf_compiler.h
index bfcdfc62f9..a93d7a184b 100644
--- a/tensorflow/compiler/xla/ptr_util.h
+++ b/tensorflow/core/platform/default/protobuf_compiler.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,23 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
-#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_COMPILER_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_COMPILER_H_
-// As this was moved to tensorflow/core/util, provide indirections here to
-// maintain current functionality of the library.
+// IWYU pragma: private, include "third_party/tensorflow/core/platform/protobuf_compiler.h"
+// IWYU pragma: friend third_party/tensorflow/core/platform/protobuf_compiler.h
-#include <stddef.h>
+#include "google/protobuf/compiler/importer.h"
+#include "tensorflow/core/platform/default/protobuf.h"
-#include <memory>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/util/ptr_util.h"
-
-namespace xla {
-using tensorflow::MakeUnique;
-using tensorflow::WrapUnique;
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_
diff --git a/tensorflow/core/kernels/warn_about_ints.h b/tensorflow/core/platform/protobuf_compiler.h
index 20666b230e..29679e0089 100644
--- a/tensorflow/core/kernels/warn_about_ints.h
+++ b/tensorflow/core/platform/protobuf_compiler.h
@@ -13,17 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
-#define TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#ifndef TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
+#define TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
-#include "tensorflow/core/framework/op_kernel.h"
+#if defined(PLATFORM_GOOGLE) && !defined(USE_DEFAULT_PROTOBUF)
+#include "tensorflow/core/platform/google/protobuf_compiler.h"
+#else
+#include "tensorflow/core/platform/default/protobuf_compiler.h"
+#endif
-namespace tensorflow {
-
-// Warn if a kernel is being created using ints
-// TODO(irving): Remove in TF 2.0 along with the bad op registrations.
-void WarnAboutInts(OpKernelConstruction* context);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#endif // TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
diff --git a/tensorflow/core/util/env_var.h b/tensorflow/core/util/env_var.h
index 47f9ff3a3b..724ca35729 100644
--- a/tensorflow/core/util/env_var.h
+++ b/tensorflow/core/util/env_var.h
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_ENV_VAR_H_
+#ifndef TENSORFLOW_CORE_UTIL_ENV_VAR_H_
+#define TENSORFLOW_CORE_UTIL_ENV_VAR_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -42,4 +43,4 @@ Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_ENV_VAR_H_
+#endif // TENSORFLOW_CORE_UTIL_ENV_VAR_H_
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 159a787d05..422be9356d 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -87,6 +87,16 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+typedef enum {
+ Dim3d_N = 0,
+ Dim3d_C = 1,
+ Dim3d_D = 2,
+ Dim3d_H = 3,
+ Dim3d_W = 4,
+ Dim3d_O = 0,
+ Dim3d_I = 1
+} MklDnnDims3D;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -351,6 +361,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -453,6 +464,13 @@ class MklDnnShape {
return this->DimSize(index);
}
+ inline size_t GetDimension3D(char dimension) const {
+ int index = GetMklDnnTensor3DDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index from the dimension: " << index << ", " << dimension;
+ return this->DimSize(index);
+ }
+
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -469,6 +487,24 @@ class MklDnnShape {
}
}
+ inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims3D::Dim3d_N;
+ case 'C':
+ return MklDnnDims3D::Dim3d_C;
+ case 'D':
+ return MklDnnDims3D::Dim3d_D;
+ case 'H':
+ return MklDnnDims3D::Dim3d_H;
+ case 'W':
+ return MklDnnDims3D::Dim3d_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
inline size_t GetDimension() const { return data_.dimension_; }
inline const int* GetSizes() const {
return reinterpret_cast<const int*>(&data_.sizes_[0]);
@@ -587,13 +623,26 @@ class MklDnnShape {
}
inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
- // TODO(nhasabni): Why do we restrict this to 4D?
- CHECK_EQ(dimension, 4);
- CHECK(dimension == data_.dimension_);
- data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
- data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
- data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
- data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ if (dimension == 5) {
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
+ MklDnnDims3D::Dim3d_D;
+ data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
+ MklDnnDims3D::Dim3d_H;
+ data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
+ MklDnnDims3D::Dim3d_W;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
+ MklDnnDims3D::Dim3d_C;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
+ MklDnnDims3D::Dim3d_N;
+ } else {
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
}
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
@@ -1329,6 +1378,19 @@ memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
}
+/// Map TensorFlow's data format into MKL-DNN 3D data format
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::ndhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::ncdhw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ return memory::format::format_undef;
+}
+
/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
@@ -1340,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1350,9 +1411,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
- if (format == memory::format::nhwc)
+ if (format == memory::format::nhwc || format == memory::format::ndhwc)
return FORMAT_NHWC;
- else if (format == memory::format::nchw)
+ else if (format == memory::format::nchw || format == memory::format::ncdhw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
@@ -1402,6 +1463,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+ int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+ int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+ int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+ // MKL-DNN requires dimensions in NCDHW format.
+ return memory::dims({n, c, d, h, w});
+}
+
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1591,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ // flat to indicate if data is 3D or not.
+ bool bIs3D;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1540,6 +1619,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index aca60b942d..ad8a44a518 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -326,7 +326,7 @@ Status ValidateStridedSliceOp(
// Even if we don't have values for begin or end, we do know that this
// dimension covers the whole interval. If we have shape information for
// this dimension, that tells us the interval length.
- if (dim_i > 0) {
+ if (dim_i >= 0) {
if (stride_i < 0) {
interval_length = -dim_i;
} else {
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows:
diff --git a/tensorflow/docs_src/about/index.md b/tensorflow/docs_src/about/index.md
index dc1e9af876..c3c13ff329 100644
--- a/tensorflow/docs_src/about/index.md
+++ b/tensorflow/docs_src/about/index.md
@@ -3,9 +3,9 @@
This section provides a few documents about TensorFlow itself,
including the following:
- * @{$uses$TensorFlow in Use}, which provides a link to our model zoo and
+ * [TensorFlow in Use](../about/uses.md), which provides a link to our model zoo and
lists some popular ways that TensorFlow is being used.
- * @{$bib$TensorFlow White Papers}, which provides abstracts of white papers
+ * [TensorFlow White Papers](../about/bib.md), which provides abstracts of white papers
about TensorFlow.
- * @{$attribution$Attribution}, which specifies how to attribute and refer
+ * [Attribution](../about/attribution.md), which specifies how to attribute and refer
to TensorFlow.
diff --git a/tensorflow/docs_src/api_guides/python/client.md b/tensorflow/docs_src/api_guides/python/client.md
index 56367e6671..fdd48e66dc 100644
--- a/tensorflow/docs_src/api_guides/python/client.md
+++ b/tensorflow/docs_src/api_guides/python/client.md
@@ -3,7 +3,7 @@
This library contains classes for launching graphs and executing operations.
-@{$guide/low_level_intro$This guide} has examples of how a graph
+[This guide](../../guide/low_level_intro.md) has examples of how a graph
is launched in a `tf.Session`.
## Session management
diff --git a/tensorflow/docs_src/api_guides/python/constant_op.md b/tensorflow/docs_src/api_guides/python/constant_op.md
index 498ec3db5d..9ba95b0f55 100644
--- a/tensorflow/docs_src/api_guides/python/constant_op.md
+++ b/tensorflow/docs_src/api_guides/python/constant_op.md
@@ -64,7 +64,7 @@ print(sess.run(norm))
```
Another common use of random values is the initialization of variables. Also see
-the @{$variables$Variables How To}.
+the [Variables How To](../../guide/variables.md).
```python
# Use random uniform values in [0, 1) as the initializer for a variable of shape
diff --git a/tensorflow/docs_src/api_guides/python/input_dataset.md b/tensorflow/docs_src/api_guides/python/input_dataset.md
index ab572e53d4..911a76c2df 100644
--- a/tensorflow/docs_src/api_guides/python/input_dataset.md
+++ b/tensorflow/docs_src/api_guides/python/input_dataset.md
@@ -2,7 +2,7 @@
[TOC]
`tf.data.Dataset` allows you to build complex input pipelines. See the
-@{$guide/datasets} for an in-depth explanation of how to use this API.
+[Importing Data](../../guide/datasets.md) for an in-depth explanation of how to use this API.
## Reader classes
diff --git a/tensorflow/docs_src/api_guides/python/io_ops.md b/tensorflow/docs_src/api_guides/python/io_ops.md
index ab3c70daa0..d7ce6fdfde 100644
--- a/tensorflow/docs_src/api_guides/python/io_ops.md
+++ b/tensorflow/docs_src/api_guides/python/io_ops.md
@@ -8,7 +8,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
## Placeholders
TensorFlow provides a placeholder operation that must be fed with data
-on execution. For more info, see the section on @{$reading_data#Feeding$Feeding data}.
+on execution. For more info, see the section on [Feeding data](../../api_guides/python/reading_data.md#Feeding).
* `tf.placeholder`
* `tf.placeholder_with_default`
@@ -21,7 +21,7 @@ there is a convenience function:
## Readers
TensorFlow provides a set of Reader classes for reading data formats.
-For more information on inputs and readers, see @{$reading_data$Reading data}.
+For more information on inputs and readers, see [Reading data](../../api_guides/python/reading_data.md).
* `tf.ReaderBase`
* `tf.TextLineReader`
@@ -42,7 +42,7 @@ formats into tensors.
### Example protocol buffer
-TensorFlow's @{$reading_data#standard_tensorflow_format$recommended format for training examples}
+TensorFlow's [recommended format for training examples](../../api_guides/python/reading_data.md#standard_tensorflow_format)
is serialized `Example` protocol buffers, [described
here](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
They contain `Features`, [described
@@ -62,7 +62,7 @@ here](https://www.tensorflow.org/code/tensorflow/core/example/feature.proto).
TensorFlow provides several implementations of 'Queues', which are
structures within the TensorFlow computation graph to stage pipelines
of tensors together. The following describe the basic Queue interface
-and some implementations. To see an example use, see @{$threading_and_queues$Threading and Queues}.
+and some implementations. To see an example use, see [Threading and Queues](../../api_guides/python/threading_and_queues.md).
* `tf.QueueBase`
* `tf.FIFOQueue`
@@ -85,7 +85,7 @@ and some implementations. To see an example use, see @{$threading_and_queues$Th
## Input pipeline
TensorFlow functions for setting up an input-prefetching pipeline.
-Please see the @{$reading_data$reading data how-to}
+Please see the [reading data how-to](../../api_guides/python/reading_data.md)
for context.
### Beginning of an input pipeline
diff --git a/tensorflow/docs_src/api_guides/python/meta_graph.md b/tensorflow/docs_src/api_guides/python/meta_graph.md
index 7dbd9a56f4..5e8a8b4d0f 100644
--- a/tensorflow/docs_src/api_guides/python/meta_graph.md
+++ b/tensorflow/docs_src/api_guides/python/meta_graph.md
@@ -23,7 +23,7 @@ protocol buffer. It contains the following fields:
* [`SaverDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/saver.proto) for the saver.
* [`CollectionDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto)
map that further describes additional components of the model such as
-@{$python/state_ops$`Variables`},
+[`Variables`](../../api_guides/python/state_ops.md),
`tf.train.QueueRunner`, etc.
In order for a Python object to be serialized
diff --git a/tensorflow/docs_src/api_guides/python/reading_data.md b/tensorflow/docs_src/api_guides/python/reading_data.md
index 78c36d965c..9f555ee85d 100644
--- a/tensorflow/docs_src/api_guides/python/reading_data.md
+++ b/tensorflow/docs_src/api_guides/python/reading_data.md
@@ -1,7 +1,7 @@
# Reading data
Note: The preferred way to feed data into a tensorflow program is using the
-@{$datasets$`tf.data` API}.
+[`tf.data` API](../../guide/datasets.md).
There are four methods of getting data into a TensorFlow program:
@@ -16,7 +16,7 @@ There are four methods of getting data into a TensorFlow program:
## `tf.data` API
-See the @{$guide/datasets} for an in-depth explanation of `tf.data.Dataset`.
+See the [Importing Data](../../guide/datasets.md) for an in-depth explanation of `tf.data.Dataset`.
The `tf.data` API enables you to extract and preprocess data
from different input/file formats, and apply transformations such as batching,
shuffling, and mapping functions over the dataset. This is an improved version
@@ -56,8 +56,8 @@ in
## `QueueRunner`
Warning: This section discusses implementing input pipelines using the
-queue-based APIs which can be cleanly replaced by the @{$datasets$`tf.data`
-API}.
+queue-based APIs which can be cleanly replaced by the [`tf.data`
+API](../../guide/datasets.md).
A typical queue-based pipeline for reading records from files has the following stages:
@@ -154,14 +154,14 @@ a uint8 tensor, standard operations can slice out each piece and reformat as
needed. For CIFAR-10, you can see how to do the reading and decoding in
[`tensorflow_models/tutorials/image/cifar10/cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py)
and described in
-@{$deep_cnn#prepare-the-data$this tutorial}.
+[this tutorial](../../tutorials/images/deep_cnn.md#prepare-the-data).
#### Standard TensorFlow format
Another approach is to convert whatever data you have into a supported format.
This approach makes it easier to mix and match data sets and network
architectures. The recommended format for TensorFlow is a
-@{$python/python_io#tfrecords_format_details$TFRecords file}
+[TFRecords file](../../api_guides/python/python_io.md#tfrecords_format_details)
containing
[`tf.train.Example` protocol buffers](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
(which contain
@@ -279,7 +279,7 @@ This can be important:
How many threads do you need? the `tf.train.shuffle_batch*` functions add a
summary to the graph that indicates how full the example queue is. If you have
enough reading threads, that summary will stay above zero. You can
-@{$summaries_and_tensorboard$view your summaries as training progresses using TensorBoard}.
+[view your summaries as training progresses using TensorBoard](../../guide/summaries_and_tensorboard.md).
### Creating threads to prefetch using `QueueRunner` objects
@@ -368,7 +368,7 @@ threads got an error when running some operation (or an ordinary Python
exception).
For more about threading, queues, QueueRunners, and Coordinators
-@{$threading_and_queues$see here}.
+[see here](../../api_guides/python/threading_and_queues.md).
#### Aside: How clean shut-down when limiting epochs works
@@ -501,18 +501,18 @@ sessions, maybe in separate processes:
model that reads validation input data.
This is what is done `tf.estimator` and manually in
-@{$deep_cnn#save-and-restore-checkpoints$the example CIFAR-10 model}.
+[the example CIFAR-10 model](../../tutorials/images/deep_cnn.md#save-and-restore-checkpoints).
This has a couple of benefits:
* The eval is performed on a single snapshot of the trained variables.
* You can perform the eval even after training has completed and exited.
You can have the train and eval in the same graph in the same process, and share
-their trained variables or layers. See @{$variables$the shared variables tutorial}.
+their trained variables or layers. See [the shared variables tutorial](../../guide/variables.md).
To support the single-graph approach
-@{$guide/datasets$`tf.data`} also supplies
-@{$guide/datasets#creating_an_iterator$advanced iterator types} that
+[`tf.data`](../../guide/datasets.md) also supplies
+[advanced iterator types](../../guide/datasets.md#creating_an_iterator) that
that allow the user to change the input pipeline without rebuilding the graph or
session.
diff --git a/tensorflow/docs_src/api_guides/python/regression_examples.md b/tensorflow/docs_src/api_guides/python/regression_examples.md
index f8abbf0f97..d67f38f57a 100644
--- a/tensorflow/docs_src/api_guides/python/regression_examples.md
+++ b/tensorflow/docs_src/api_guides/python/regression_examples.md
@@ -66,7 +66,7 @@ watch the following video:
<a name="running"></a>
## Running the examples
-You must @{$install$install TensorFlow} prior to running these examples.
+You must [install TensorFlow](../../install/index.md) prior to running these examples.
Depending on the way you've installed TensorFlow, you might also
need to activate your TensorFlow environment. Then, do the following:
diff --git a/tensorflow/docs_src/api_guides/python/summary.md b/tensorflow/docs_src/api_guides/python/summary.md
index e290703b7d..fc45e7b4c3 100644
--- a/tensorflow/docs_src/api_guides/python/summary.md
+++ b/tensorflow/docs_src/api_guides/python/summary.md
@@ -2,7 +2,7 @@
[TOC]
Summaries provide a way to export condensed information about a model, which is
-then accessible in tools such as @{$summaries_and_tensorboard$TensorBoard}.
+then accessible in tools such as [TensorBoard](../../guide/summaries_and_tensorboard.md).
## Generation of Summaries
diff --git a/tensorflow/docs_src/api_guides/python/threading_and_queues.md b/tensorflow/docs_src/api_guides/python/threading_and_queues.md
index 48f0778b73..e00f17f955 100644
--- a/tensorflow/docs_src/api_guides/python/threading_and_queues.md
+++ b/tensorflow/docs_src/api_guides/python/threading_and_queues.md
@@ -3,7 +3,7 @@
Note: In versions of TensorFlow before 1.2, we recommended using multi-threaded,
queue-based input pipelines for performance. Beginning with TensorFlow 1.4,
however, we recommend using the `tf.data` module instead. (See
-@{$datasets$Datasets} for details. In TensorFlow 1.2 and 1.3, the module was
+[Datasets](../../guide/datasets.md) for details. In TensorFlow 1.2 and 1.3, the module was
called `tf.contrib.data`.) The `tf.data` module offers an easier-to-use
interface for constructing efficient input pipelines. Furthermore, we've stopped
developing the old multi-threaded, queue-based input pipelines. We've retained
diff --git a/tensorflow/docs_src/api_guides/python/train.md b/tensorflow/docs_src/api_guides/python/train.md
index a118123665..4b4c6a4fe3 100644
--- a/tensorflow/docs_src/api_guides/python/train.md
+++ b/tensorflow/docs_src/api_guides/python/train.md
@@ -74,9 +74,9 @@ moving averages for evaluations often improve results significantly.
## Coordinator and QueueRunner
-See @{$threading_and_queues$Threading and Queues}
+See [Threading and Queues](../../api_guides/python/threading_and_queues.md)
for how to use threads and queues. For documentation on the Queue API,
-see @{$python/io_ops#queues$Queues}.
+see [Queues](../../api_guides/python/io_ops.md#queues).
* `tf.train.Coordinator`
@@ -87,7 +87,7 @@ see @{$python/io_ops#queues$Queues}.
## Distributed execution
-See @{$distributed$Distributed TensorFlow} for
+See [Distributed TensorFlow](../../deploy/distributed.md) for
more information about how to configure a distributed TensorFlow program.
* `tf.train.Server`
@@ -105,7 +105,7 @@ more information about how to configure a distributed TensorFlow program.
## Reading Summaries from Event Files
-See @{$summaries_and_tensorboard$Summaries and TensorBoard} for an
+See [Summaries and TensorBoard](../../guide/summaries_and_tensorboard.md) for an
overview of summaries, event files, and visualization in TensorBoard.
* `tf.train.summary_iterator`
diff --git a/tensorflow/docs_src/community/contributing.md b/tensorflow/docs_src/community/contributing.md
index afbb8bbdd0..ece4a7c70b 100644
--- a/tensorflow/docs_src/community/contributing.md
+++ b/tensorflow/docs_src/community/contributing.md
@@ -25,12 +25,12 @@ guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md
[developers@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/developers)
mailing list, to coordinate and discuss with others contributing to TensorFlow.
-* For coding style conventions, read the @{$style_guide$TensorFlow Style Guide}.
+* For coding style conventions, read the [TensorFlow Style Guide](../community/style_guide.md).
-* Finally, review @{$documentation$Writing TensorFlow Documentation}, which
+* Finally, review [Writing TensorFlow Documentation](../community/documentation.md), which
explains documentation conventions.
-You may also wish to review our guide to @{$benchmarks$defining and running benchmarks}.
+You may also wish to review our guide to [defining and running benchmarks](../community/benchmarks.md).
## Special Interest Groups
diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md
index 865a203bf8..1a30be32a5 100644
--- a/tensorflow/docs_src/community/index.md
+++ b/tensorflow/docs_src/community/index.md
@@ -40,7 +40,7 @@ We recommend that you join this list if you depend on TensorFlow in any way.
### Development Roadmap
-The @{$roadmap$Roadmap} summarizes plans for upcoming additions to TensorFlow.
+The [Roadmap](../community/roadmap.md) summarizes plans for upcoming additions to TensorFlow.
### Social Media
@@ -70,12 +70,12 @@ the [TensorFlow discuss mailing
list](https://groups.google.com/a/tensorflow.org/d/forum/discuss).
A number of other mailing lists exist, focused on different project areas, which
-can be found at @{$lists$TensorFlow Mailing Lists}.
+can be found at [TensorFlow Mailing Lists](../community/lists.md).
### User Groups
To meet with like-minded people local to you, check out the many
-@{$groups$TensorFlow user groups} around the world.
+[TensorFlow user groups](../community/groups.md) around the world.
## Contributing To TensorFlow
diff --git a/tensorflow/docs_src/community/style_guide.md b/tensorflow/docs_src/community/style_guide.md
index daf0d2fdc0..c78da20edd 100644
--- a/tensorflow/docs_src/community/style_guide.md
+++ b/tensorflow/docs_src/community/style_guide.md
@@ -88,7 +88,7 @@ creates a part of the graph and returns output tensors.
* Operations should contain an extensive Python comment with Args and Returns
declarations that explain both the type and meaning of each value. Possible
shapes, dtypes, or ranks should be specified in the description.
- @{$documentation$See documentation details}
+ [See documentation details](../community/documentation.md)
* For increased usability include an example of usage with inputs / outputs
of the op in Example section.
diff --git a/tensorflow/docs_src/deploy/distributed.md b/tensorflow/docs_src/deploy/distributed.md
index 6a760f53c8..2fba36cfa7 100644
--- a/tensorflow/docs_src/deploy/distributed.md
+++ b/tensorflow/docs_src/deploy/distributed.md
@@ -2,7 +2,7 @@
This document shows how to create a cluster of TensorFlow servers, and how to
distribute a computation graph across that cluster. We assume that you are
-familiar with the @{$guide/low_level_intro$basic concepts} of
+familiar with the [basic concepts](../guide/low_level_intro.md) of
writing low level TensorFlow programs.
## Hello distributed TensorFlow!
diff --git a/tensorflow/docs_src/deploy/hadoop.md b/tensorflow/docs_src/deploy/hadoop.md
index c4471562b9..b0d416df2e 100644
--- a/tensorflow/docs_src/deploy/hadoop.md
+++ b/tensorflow/docs_src/deploy/hadoop.md
@@ -6,7 +6,7 @@ at the moment.
## HDFS
-We assume that you are familiar with @{$reading_data$reading data}.
+We assume that you are familiar with [reading data](../api_guides/python/reading_data.md).
To use HDFS with TensorFlow, change the file paths you use to read and write
data to an HDFS path. For example:
@@ -61,5 +61,5 @@ be set:
export KRB5CCNAME=/tmp/krb5cc_10002
```
-If you are running @{$distributed$Distributed TensorFlow}, then all
+If you are running [Distributed TensorFlow](../deploy/distributed.md), then all
workers must have the environment variables set and Hadoop installed.
diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md
index 3322004189..08b28de639 100644
--- a/tensorflow/docs_src/deploy/index.md
+++ b/tensorflow/docs_src/deploy/index.md
@@ -3,11 +3,11 @@
This section focuses on deploying real-world models. It contains
the following documents:
- * @{$distributed$Distributed TensorFlow}, which explains how to create
+ * [Distributed TensorFlow](../deploy/distributed.md), which explains how to create
a cluster of TensorFlow servers.
- * @{$hadoop$How to run TensorFlow on Hadoop}, which has a highly
+ * [How to run TensorFlow on Hadoop](../deploy/hadoop.md), which has a highly
self-explanatory title.
- * @{$s3$How to run TensorFlow with the S3 filesystem}, which explains how
+ * [How to run TensorFlow with the S3 filesystem](../deploy/s3.md), which explains how
to run TensorFlow with the S3 file system.
* The entire document set for [TensorFlow serving](/serving), an open-source,
flexible, high-performance serving system for machine-learned models
diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md
index 079c796aa7..b4a759d687 100644
--- a/tensorflow/docs_src/deploy/s3.md
+++ b/tensorflow/docs_src/deploy/s3.md
@@ -64,7 +64,7 @@ You should see output similar to this:
### Reading Data
-When @{$reading_data$reading data}, change the file paths you use to read and write
+When [reading data](../api_guides/python/reading_data.md), change the file paths you use to read and write
data to an S3 path. For example:
```python
diff --git a/tensorflow/docs_src/extend/add_filesys.md b/tensorflow/docs_src/extend/add_filesys.md
index bc0f662f0c..5f8ac64d25 100644
--- a/tensorflow/docs_src/extend/add_filesys.md
+++ b/tensorflow/docs_src/extend/add_filesys.md
@@ -225,7 +225,7 @@ it will use the `FooBarFileSystem` implementation.
Next, you must build a shared object containing this implementation. An example
of doing so using bazel's `cc_binary` rule can be found
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/BUILD#L244),
-but you may use any build system to do so. See the section on @{$adding_an_op#build_the_op_library$building the op library} for similar
+but you may use any build system to do so. See the section on [building the op library](../extend/adding_an_op.md#build_the_op_library) for similar
instructions.
The result of building this target is a `.so` shared object file.
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index fbf5c0b90d..cc25ab9b45 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -56,8 +56,8 @@ PREREQUISITES:
* Some familiarity with C++.
* Must have installed the
- @{$install$TensorFlow binary}, or must have
- @{$install_sources$downloaded TensorFlow source},
+ [TensorFlow binary](../install/index.md), or must have
+ [downloaded TensorFlow source](../install/install_sources.md),
and be able to build it.
[TOC]
@@ -1140,7 +1140,7 @@ In general, changes to existing, checked-in specifications must be
backwards-compatible: changing the specification of an op must not break prior
serialized `GraphDef` protocol buffers constructed from older specifications.
The details of `GraphDef` compatibility are
-@{$version_compat#compatibility_of_graphs_and_checkpoints$described here}.
+[described here](../guide/version_compat.md#compatibility_of_graphs_and_checkpoints).
There are several ways to preserve backwards-compatibility.
@@ -1190,7 +1190,7 @@ callers. The Python API may be kept compatible by careful changes in a
hand-written Python wrapper, by keeping the old signature except possibly adding
new optional arguments to the end. Generally incompatible changes may only be
made when TensorFlow's changes major versions, and must conform to the
-@{$version_compat#compatibility_of_graphs_and_checkpoints$`GraphDef` version semantics}.
+[`GraphDef` version semantics](../guide/version_compat.md#compatibility_of_graphs_and_checkpoints).
### GPU Support
@@ -1262,7 +1262,7 @@ For example, add `-L /usr/local/cuda-8.0/lib64/` if your CUDA is installed in
Given a graph of ops, TensorFlow uses automatic differentiation
(backpropagation) to add new ops representing gradients with respect to the
existing ops (see
-@{$python/train#gradient_computation$Gradient Computation}).
+[Gradient Computation](../api_guides/python/train.md#gradient_computation)).
To make automatic differentiation work for new ops, you must register a gradient
function which computes gradients with respect to the ops' inputs given
gradients with respect to the ops' outputs.
diff --git a/tensorflow/docs_src/extend/architecture.md b/tensorflow/docs_src/extend/architecture.md
index 83d70c9468..eb33336bee 100644
--- a/tensorflow/docs_src/extend/architecture.md
+++ b/tensorflow/docs_src/extend/architecture.md
@@ -7,8 +7,8 @@ learning models and system-level optimizations.
This document describes the system architecture that makes this
combination of scale and flexibility possible. It assumes that you have basic familiarity
with TensorFlow programming concepts such as the computation graph, operations,
-and sessions. See @{$guide/low_level_intro$this document} for an introduction to
-these topics. Some familiarity with @{$distributed$distributed TensorFlow}
+and sessions. See [this document](../guide/low_level_intro.md) for an introduction to
+these topics. Some familiarity with [distributed TensorFlow](../deploy/distributed.md)
will also be helpful.
This document is for developers who want to extend TensorFlow in some way not
@@ -199,7 +199,7 @@ Many of the operation kernels are implemented using Eigen::Tensor, which uses
C++ templates to generate efficient parallel code for multicore CPUs and GPUs;
however, we liberally use libraries like cuDNN where a more efficient kernel
implementation is possible. We have also implemented
-@{$quantization$quantization}, which enables
+[quantization](../performance/quantization.md), which enables
faster inference in environments such as mobile devices and high-throughput
datacenter applications, and use the
[gemmlowp](https://github.com/google/gemmlowp) low-precision matrix library to
@@ -209,7 +209,7 @@ If it is difficult or inefficient to represent a subcomputation as a composition
of operations, users can register additional kernels that provide an efficient
implementation written in C++. For example, we recommend registering your own
fused kernels for some performance critical operations, such as the ReLU and
-Sigmoid activation functions and their corresponding gradients. The @{$xla$XLA Compiler} has an
+Sigmoid activation functions and their corresponding gradients. The [XLA Compiler](../performance/xla/index.md) has an
experimental implementation of automatic kernel fusion.
### Code
diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md
index 0e4bfd1dc4..bbf4a8139b 100644
--- a/tensorflow/docs_src/extend/index.md
+++ b/tensorflow/docs_src/extend/index.md
@@ -3,16 +3,16 @@
This section explains how developers can add functionality to TensorFlow's
capabilities. Begin by reading the following architectural overview:
- * @{$architecture$TensorFlow Architecture}
+ * [TensorFlow Architecture](../extend/architecture.md)
The following guides explain how to extend particular aspects of
TensorFlow:
- * @{$adding_an_op$Adding a New Op}, which explains how to create your own
+ * [Adding a New Op](../extend/adding_an_op.md), which explains how to create your own
operations.
- * @{$add_filesys$Adding a Custom Filesystem Plugin}, which explains how to
+ * [Adding a Custom Filesystem Plugin](../extend/add_filesys.md), which explains how to
add support for your own shared or distributed filesystem.
- * @{$new_data_formats$Custom Data Readers}, which details how to add support
+ * [Custom Data Readers](../extend/new_data_formats.md), which details how to add support
for your own file and record formats.
Python is currently the only language supported by TensorFlow's API stability
@@ -24,11 +24,11 @@ plus community support for [Haskell](https://github.com/tensorflow/haskell) and
develop TensorFlow features in a language other than these languages, read the
following guide:
- * @{$language_bindings$TensorFlow in Other Languages}
+ * [TensorFlow in Other Languages](../extend/language_bindings.md)
To create tools compatible with TensorFlow's model format, read the following
guide:
- * @{$tool_developers$A Tool Developer's Guide to TensorFlow Model Files}
+ * [A Tool Developer's Guide to TensorFlow Model Files](../extend/tool_developers/index.md)
diff --git a/tensorflow/docs_src/extend/language_bindings.md b/tensorflow/docs_src/extend/language_bindings.md
index 9a968d365b..4727eabdc1 100644
--- a/tensorflow/docs_src/extend/language_bindings.md
+++ b/tensorflow/docs_src/extend/language_bindings.md
@@ -125,7 +125,7 @@ The `OpDef` specifies the following:
instead of CamelCase for the op's function name.
- A list of inputs and outputs. The types for these may be polymorphic by
referencing attributes, as described in the inputs and outputs section of
- @{$adding_an_op$Adding an op}.
+ [Adding an op](../extend/adding_an_op.md).
- A list of attributes, along with their default values (if any). Note that
some of these will be inferred (if they are determined by an input), some
will be optional (if they have a default), and some will be required (no
diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md
index 47a8344b70..7ca50c9c76 100644
--- a/tensorflow/docs_src/extend/new_data_formats.md
+++ b/tensorflow/docs_src/extend/new_data_formats.md
@@ -4,7 +4,7 @@ PREREQUISITES:
* Some familiarity with C++.
* Must have
- @{$install_sources$downloaded TensorFlow source}, and be
+ [downloaded TensorFlow source](../install/install_sources.md), and be
able to build it.
We divide the task of supporting a file format into two pieces:
@@ -67,7 +67,7 @@ need to:
You can put all the C++ code in a single file, such as
`my_reader_dataset_op.cc`. It will help if you are
-familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
+familiar with [the adding an op how-to](../extend/adding_an_op.md). The following skeleton
can be used as a starting point for your implementation:
```c++
@@ -227,8 +227,8 @@ REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU),
```
The last step is to build the C++ code and add a Python wrapper. The easiest way
-to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
-library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
+to do this is by [compiling a dynamic
+library](../extend/adding_an_op.md#build_the_op_library) (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
that subclasses `tf.data.Dataset` to wrap it. An example Python program is
given here:
@@ -285,7 +285,7 @@ You can see some examples of `Dataset` wrapper classes in
## Writing an Op for a record format
Generally this is an ordinary op that takes a scalar string record as input, and
-so follow @{$adding_an_op$the instructions to add an Op}.
+so follow [the instructions to add an Op](../extend/adding_an_op.md).
You may optionally take a scalar string key as input, and include that in error
messages reporting improperly formatted data. That way users can more easily
track down where the bad data came from.
diff --git a/tensorflow/docs_src/guide/checkpoints.md b/tensorflow/docs_src/guide/checkpoints.md
index e1add29852..3c92cbbd40 100644
--- a/tensorflow/docs_src/guide/checkpoints.md
+++ b/tensorflow/docs_src/guide/checkpoints.md
@@ -9,13 +9,13 @@ Estimators. TensorFlow provides two model formats:
the model.
This document focuses on checkpoints. For details on `SavedModel`, see the
-@{$saved_model$Saving and Restoring} guide.
+[Saving and Restoring](../guide/saved_model.md) guide.
## Sample code
This document relies on the same
-[Iris classification example](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py) detailed in @{$premade_estimators$Getting Started with TensorFlow}.
+[Iris classification example](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py) detailed in [Getting Started with TensorFlow](../guide/premade_estimators.md).
To download and access the example, invoke the following two commands:
```shell
@@ -160,7 +160,7 @@ checkpoint to the `model_dir`. Each subsequent call to the Estimator's
1. The Estimator builds the model's
[graph](https://developers.google.com/machine-learning/glossary/#graph)
by running the `model_fn()`. (For details on the `model_fn()`, see
- @{$custom_estimators$Creating Custom Estimators.})
+ [Creating Custom Estimators.](../guide/custom_estimators.md))
2. The Estimator initializes the weights of the new model from the data
stored in the most recent checkpoint.
@@ -231,7 +231,7 @@ This separation will keep your checkpoints recoverable.
Checkpoints provide an easy automatic mechanism for saving and restoring
models created by Estimators.
-See the @{$saved_model$Saving and Restoring} guide for details about:
+See the [Saving and Restoring](../guide/saved_model.md) guide for details about:
* Saving and restoring models using low-level TensorFlow APIs.
* Exporting and importing models in the SavedModel format, which is a
diff --git a/tensorflow/docs_src/guide/custom_estimators.md b/tensorflow/docs_src/guide/custom_estimators.md
index 199a0e93de..913a35920f 100644
--- a/tensorflow/docs_src/guide/custom_estimators.md
+++ b/tensorflow/docs_src/guide/custom_estimators.md
@@ -5,7 +5,7 @@ This document introduces custom Estimators. In particular, this document
demonstrates how to create a custom `tf.estimator.Estimator` that
mimics the behavior of the pre-made Estimator
`tf.estimator.DNNClassifier` in solving the Iris problem. See
-the @{$premade_estimators$Pre-Made Estimators chapter} for details
+the [Pre-Made Estimators chapter](../guide/premade_estimators.md) for details
on the Iris problem.
To download and access the example code invoke the following two commands:
@@ -84,7 +84,7 @@ and a logits output layer.
## Write an Input function
Our custom Estimator implementation uses the same input function as our
-@{$premade_estimators$pre-made Estimator implementation}, from
+[pre-made Estimator implementation](../guide/premade_estimators.md), from
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py).
Namely:
@@ -106,8 +106,8 @@ This input function builds an input pipeline that yields batches of
## Create feature columns
-As detailed in the @{$premade_estimators$Premade Estimators} and
-@{$feature_columns$Feature Columns} chapters, you must define
+As detailed in the [Premade Estimators](../guide/premade_estimators.md) and
+[Feature Columns](../guide/feature_columns.md) chapters, you must define
your model's feature columns to specify how the model should use each feature.
Whether working with pre-made Estimators or custom Estimators, you define
feature columns in the same fashion.
@@ -145,7 +145,7 @@ to the constructor are in turn passed on to the `model_fn`. In
[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
the following lines create the estimator and set the params to configure the
model. This configuration step is similar to how we configured the `tf.estimator.DNNClassifier` in
-@{$premade_estimators}.
+[Premade Estimators](../guide/premade_estimators.md).
```python
classifier = tf.estimator.Estimator(
@@ -489,7 +489,7 @@ configure your Estimator without modifying the code in the `model_fn`.
The rest of the code to train, evaluate, and generate predictions using our
Estimator is the same as in the
-@{$premade_estimators$Premade Estimators} chapter. For
+[Premade Estimators](../guide/premade_estimators.md) chapter. For
example, the following line will train the model:
```python
@@ -597,6 +597,6 @@ For more details, be sure to check out:
which contains more curated examples using custom estimators.
* This [TensorBoard video](https://youtu.be/eBbEDRsCmv4), which introduces
TensorBoard.
-* The @{$low_level_intro$Low Level Introduction}, which demonstrates
+* The [Low Level Introduction](../guide/low_level_intro.md), which demonstrates
how to experiment directly with TensorFlow's low level APIs, making debugging
easier.
diff --git a/tensorflow/docs_src/guide/datasets.md b/tensorflow/docs_src/guide/datasets.md
index bb18e8b79c..60de181b21 100644
--- a/tensorflow/docs_src/guide/datasets.md
+++ b/tensorflow/docs_src/guide/datasets.md
@@ -335,7 +335,7 @@ restore the current state of the iterator (and, effectively, the whole input
pipeline). A saveable object thus created can be added to `tf.train.Saver`
variables list or the `tf.GraphKeys.SAVEABLE_OBJECTS` collection for saving and
restoring in the same manner as a `tf.Variable`. Refer to
-@{$saved_model$Saving and Restoring} for details on how to save and restore
+[Saving and Restoring](../guide/saved_model.md) for details on how to save and restore
variables.
```python
@@ -782,8 +782,9 @@ with tf.train.MonitoredTrainingSession(...) as sess:
sess.run(training_op)
```
-To use a `Dataset` in the `input_fn` of a `tf.estimator.Estimator`, we also
-recommend using `Dataset.make_one_shot_iterator()`. For example:
+To use a `Dataset` in the `input_fn` of a `tf.estimator.Estimator`, simply
+return the `Dataset` and the framework will take care of creating an iterator
+and initializing it for you. For example:
```python
def dataset_input_fn():
@@ -814,10 +815,9 @@ def dataset_input_fn():
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
- iterator = dataset.make_one_shot_iterator()
- # `features` is a dictionary in which each value is a batch of values for
- # that feature; `labels` is a batch of labels.
- features, labels = iterator.get_next()
- return features, labels
+ # Each element of `dataset` is tuple containing a dictionary of features
+ # (in which each value is a batch of values for that feature), and a batch of
+ # labels.
+ return dataset
```
diff --git a/tensorflow/docs_src/guide/datasets_for_estimators.md b/tensorflow/docs_src/guide/datasets_for_estimators.md
index 969ea579f7..09a3830ca9 100644
--- a/tensorflow/docs_src/guide/datasets_for_estimators.md
+++ b/tensorflow/docs_src/guide/datasets_for_estimators.md
@@ -14,7 +14,7 @@ introduces the API by walking through two simple examples:
Taking slices from an array is the simplest way to get started with `tf.data`.
-The @{$premade_estimators$Premade Estimators} chapter describes
+The [Premade Estimators](../guide/premade_estimators.md) chapter describes
the following `train_input_fn`, from
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py),
to pipe the data into the Estimator:
@@ -91,8 +91,8 @@ print(mnist_ds)
```
This will print the following line, showing the
-@{$guide/tensors#shapes$shapes} and
-@{$guide/tensors#data_types$types} of the items in
+[shapes](../guide/tensors.md#shapes) and
+[types](../guide/tensors.md#data_types) of the items in
the dataset. Note that a `Dataset` does not know how many items it contains.
``` None
@@ -128,7 +128,7 @@ print(dataset)
Here we see that when a `Dataset` contains structured elements, the `shapes`
and `types` of the `Dataset` take on the same structure. This dataset contains
-dictionaries of @{$guide/tensors#rank$scalars}, all of type
+dictionaries of [scalars](../guide/tensors.md#rank), all of type
`tf.float64`.
The first line of the iris `train_input_fn` uses the same functionality, but
@@ -377,11 +377,11 @@ Now you have the basic idea of how to efficiently load data into an
Estimator. Consider the following documents next:
-* @{$custom_estimators}, which demonstrates how to build your own
+* [Creating Custom Estimators](../guide/custom_estimators.md), which demonstrates how to build your own
custom `Estimator` model.
-* The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates
+* The [Low Level Introduction](../guide/low_level_intro.md#datasets), which demonstrates
how to experiment directly with `tf.data.Datasets` using TensorFlow's low
level APIs.
-* @{$guide/datasets} which goes into great detail about additional
+* [Importing Data](../guide/datasets.md) which goes into great detail about additional
functionality of `Datasets`.
diff --git a/tensorflow/docs_src/guide/debugger.md b/tensorflow/docs_src/guide/debugger.md
index 4c4a04a88a..5af27471a2 100644
--- a/tensorflow/docs_src/guide/debugger.md
+++ b/tensorflow/docs_src/guide/debugger.md
@@ -95,7 +95,7 @@ intermediate tensors (tensors that are neither inputs or outputs of the
`Session.run()` call, but are in the path leading from the inputs to the
outputs). This filter is for `nan`s and `inf`s is a common enough use case that
we ship it with the
-@{$python/tfdbg#Classes_for_debug_dump_data_and_directories$`debug_data`}
+[`debug_data`](../api_guides/python/tfdbg.md#Classes_for_debug_dump_data_and_directories)
module.
Note: You can also write your own custom filters. See `tfdbg.DebugDumpDir.find`
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index e47a8b599c..3b5797a638 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -558,7 +558,7 @@ m.result() # => 5.5
#### Summaries and TensorBoard
-@{$summaries_and_tensorboard$TensorBoard} is a visualization tool for
+[TensorBoard](../guide/summaries_and_tensorboard.md) is a visualization tool for
understanding, debugging and optimizing the model training process. It uses
summary events that are written while executing the program.
diff --git a/tensorflow/docs_src/guide/embedding.md b/tensorflow/docs_src/guide/embedding.md
index 8a98367dfb..6007e6847b 100644
--- a/tensorflow/docs_src/guide/embedding.md
+++ b/tensorflow/docs_src/guide/embedding.md
@@ -78,7 +78,7 @@ Embeddings can be trained in many network types, and with various loss
functions and data sets. For example, one could use a recurrent neural network
to predict the next word from the previous one given a large corpus of
sentences, or one could train two networks to do multi-lingual translation.
-These methods are described in the @{$word2vec$Vector Representations of Words}
+These methods are described in the [Vector Representations of Words](../tutorials/representation/word2vec.md)
tutorial.
## Visualizing Embeddings
diff --git a/tensorflow/docs_src/guide/estimators.md b/tensorflow/docs_src/guide/estimators.md
index 7b54e3de29..3903bfd126 100644
--- a/tensorflow/docs_src/guide/estimators.md
+++ b/tensorflow/docs_src/guide/estimators.md
@@ -84,7 +84,7 @@ of the following four steps:
... # manipulate dataset, extracting the feature dict and the label
return feature_dict, label
- (See @{$guide/datasets} for full details.)
+ (See [Importing Data](../guide/datasets.md) for full details.)
2. **Define the feature columns.** Each `tf.feature_column`
identifies a feature name, its type, and any input pre-processing.
@@ -136,7 +136,7 @@ The heart of every Estimator--whether pre-made or custom--is its
evaluation, and prediction. When you are using a pre-made Estimator,
someone else has already implemented the model function. When relying
on a custom Estimator, you must write the model function yourself. A
-@{$custom_estimators$companion document}
+[companion document](../guide/custom_estimators.md)
explains how to write the model function.
diff --git a/tensorflow/docs_src/guide/faq.md b/tensorflow/docs_src/guide/faq.md
index 8370097560..a02635ebba 100644
--- a/tensorflow/docs_src/guide/faq.md
+++ b/tensorflow/docs_src/guide/faq.md
@@ -2,7 +2,7 @@
This document provides answers to some of the frequently asked questions about
TensorFlow. If you have a question that is not covered here, you might find an
-answer on one of the TensorFlow @{$about$community resources}.
+answer on one of the TensorFlow [community resources](../about/index.md).
[TOC]
@@ -11,7 +11,7 @@ answer on one of the TensorFlow @{$about$community resources}.
#### Can I run distributed training on multiple computers?
Yes! TensorFlow gained
-@{$distributed$support for distributed computation} in
+[support for distributed computation](../deploy/distributed.md) in
version 0.8. TensorFlow now supports multiple devices (CPUs and GPUs) in one or
more computers.
@@ -23,7 +23,7 @@ As of the 0.6.0 release timeframe (Early December 2015), we do support Python
## Building a TensorFlow graph
See also the
-@{$python/framework$API documentation on building graphs}.
+[API documentation on building graphs](../api_guides/python/framework.md).
#### Why does `c = tf.matmul(a, b)` not execute the matrix multiplication immediately?
@@ -48,16 +48,16 @@ device, and `"/device:GPU:i"` (or `"/gpu:i"`) for the *i*th GPU device.
To place a group of operations on a device, create them within a
`tf.device` context. See
the how-to documentation on
-@{$using_gpu$using GPUs with TensorFlow} for details of how
+[using GPUs with TensorFlow](../guide/using_gpu.md) for details of how
TensorFlow assigns operations to devices, and the
-@{$deep_cnn$CIFAR-10 tutorial} for an example model that
+[CIFAR-10 tutorial](../tutorials/images/deep_cnn.md) for an example model that
uses multiple GPUs.
## Running a TensorFlow computation
See also the
-@{$python/client$API documentation on running graphs}.
+[API documentation on running graphs](../api_guides/python/client.md).
#### What's the deal with feeding and placeholders?
@@ -106,7 +106,7 @@ a significant amount of memory, and can be released when the session is closed b
`tf.Session.close`.
The intermediate tensors that are created as part of a call to
-@{$python/client$`Session.run()`} will be freed at or before the
+[`Session.run()`](../api_guides/python/client.md) will be freed at or before the
end of the call.
#### Does the runtime parallelize parts of graph execution?
@@ -118,7 +118,7 @@ dimensions:
CPU, or multiple threads in a GPU.
* Independent nodes in a TensorFlow graph can run in parallel on multiple
devices, which makes it possible to speed up
- @{$deep_cnn$CIFAR-10 training using multiple GPUs}.
+ [CIFAR-10 training using multiple GPUs](../tutorials/images/deep_cnn.md).
* The Session API allows multiple concurrent steps (i.e. calls to
`tf.Session.run` in parallel). This
enables the runtime to get higher throughput, if a single step does not use
@@ -141,9 +141,9 @@ Bindings for various other languages (such as [C#](https://github.com/migueldeic
#### Does TensorFlow make use of all the devices (GPUs and CPUs) available on my machine?
TensorFlow supports multiple GPUs and CPUs. See the how-to documentation on
-@{$using_gpu$using GPUs with TensorFlow} for details of how
+[using GPUs with TensorFlow](../guide/using_gpu.md) for details of how
TensorFlow assigns operations to devices, and the
-@{$deep_cnn$CIFAR-10 tutorial} for an example model that
+[CIFAR-10 tutorial](../tutorials/images/deep_cnn.md) for an example model that
uses multiple GPUs.
Note that TensorFlow only uses GPU devices with a compute capability greater
@@ -155,16 +155,16 @@ The `tf.ReaderBase` and
`tf.QueueBase` classes provide special operations that
can *block* until input (or free space in a bounded queue) becomes
available. These operations allow you to build sophisticated
-@{$reading_data$input pipelines}, at the cost of making the
+[input pipelines](../api_guides/python/reading_data.md), at the cost of making the
TensorFlow computation somewhat more complicated. See the how-to documentation
for
-@{$reading_data#creating_threads_to_prefetch_using_queuerunner_objects$using `QueueRunner` objects to drive queues and readers}
+[using `QueueRunner` objects to drive queues and readers](../api_guides/python/reading_data.md#creating_threads_to_prefetch_using_queuerunner_objects)
for more information on how to use them.
## Variables
-See also the how-to documentation on @{$variables$variables} and
-@{$python/state_ops$the API documentation for variables}.
+See also the how-to documentation on [variables](../guide/variables.md) and
+[the API documentation for variables](../api_guides/python/state_ops.md).
#### What is the lifetime of a variable?
@@ -231,7 +231,7 @@ to encode the batch size as a Python constant, but instead to use a symbolic
#### How can I visualize a TensorFlow graph?
-See the @{$graph_viz$graph visualization tutorial}.
+See the [graph visualization tutorial](../guide/graph_viz.md).
#### What is the simplest way to send data to TensorBoard?
@@ -241,7 +241,7 @@ these summaries to a log directory. Then, start TensorBoard using
python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
For more details, see the
-@{$summaries_and_tensorboard$Summaries and TensorBoard tutorial}.
+[Summaries and TensorBoard tutorial](../guide/summaries_and_tensorboard.md).
#### Every time I launch TensorBoard, I get a network security popup!
@@ -251,7 +251,7 @@ the flag --host=localhost. This should quiet any security warnings.
## Extending TensorFlow
See the how-to documentation for
-@{$adding_an_op$adding a new operation to TensorFlow}.
+[adding a new operation to TensorFlow](../extend/adding_an_op.md).
#### My data is in a custom format. How do I read it using TensorFlow?
@@ -273,8 +273,8 @@ consider converting it, offline, to a format that is easily parsable, such
as `tf.python_io.TFRecordWriter` format.
The most efficient method to customize the parsing behavior is to
-@{$adding_an_op$add a new op written in C++} that parses your
-data format. The @{$new_data_formats$guide to handling new data formats} has
+[add a new op written in C++](../extend/adding_an_op.md) that parses your
+data format. The [guide to handling new data formats](../extend/new_data_formats.md) has
more information about the steps for doing this.
diff --git a/tensorflow/docs_src/guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md
index b189c4334e..3ad41855e4 100644
--- a/tensorflow/docs_src/guide/feature_columns.md
+++ b/tensorflow/docs_src/guide/feature_columns.md
@@ -5,7 +5,7 @@ intermediaries between raw data and Estimators. Feature columns are very rich,
enabling you to transform a diverse range of raw data into formats that
Estimators can use, allowing easy experimentation.
-In @{$premade_estimators$Premade Estimators}, we used the premade
+In [Premade Estimators](../guide/premade_estimators.md), we used the premade
Estimator, `tf.estimator.DNNClassifier` to train a model to
predict different types of Iris flowers from four input features. That example
created only numerical feature columns (of type
@@ -534,7 +534,7 @@ embedding_column = tf.feature_column.embedding_column(
dimension=embedding_dimensions)
```
-@{$guide/embedding$Embeddings} is a significant topic within machine
+[Embeddings](../guide/embedding.md) is a significant topic within machine
learning. This information was just to get you started using them as feature
columns.
@@ -559,7 +559,7 @@ As the following list indicates, not all Estimators permit all types of
For more examples on feature columns, view the following:
-* The @{$low_level_intro#feature_columns$Low Level Introduction} demonstrates how
+* The [Low Level Introduction](../guide/low_level_intro.md#feature_columns) demonstrates how
experiment directly with `feature_columns` using TensorFlow's low level APIs.
* The [Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
solves a binary classification problem using `feature_columns` on a variety of
diff --git a/tensorflow/docs_src/guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md
index 97b0e2d4de..23f722bbe7 100644
--- a/tensorflow/docs_src/guide/graph_viz.md
+++ b/tensorflow/docs_src/guide/graph_viz.md
@@ -5,7 +5,7 @@ TensorFlow computation graphs are powerful but complicated. The graph visualizat
![Visualization of a TensorFlow graph](https://www.tensorflow.org/images/graph_vis_animation.gif "Visualization of a TensorFlow graph")
*Visualization of a TensorFlow graph.*
-To see your own graph, run TensorBoard pointing it to the log directory of the job, click on the graph tab on the top pane and select the appropriate run using the menu at the upper left corner. For in depth information on how to run TensorBoard and make sure you are logging all the necessary information, see @{$summaries_and_tensorboard$TensorBoard: Visualizing Learning}.
+To see your own graph, run TensorBoard pointing it to the log directory of the job, click on the graph tab on the top pane and select the appropriate run using the menu at the upper left corner. For in depth information on how to run TensorBoard and make sure you are logging all the necessary information, see [TensorBoard: Visualizing Learning](../guide/summaries_and_tensorboard.md).
## Name scoping and nodes
@@ -251,7 +251,7 @@ is a snippet from the train and test section of a modification of the
[Estimators MNIST tutorial](../tutorials/estimators/cnn.md), in which we have
recorded summaries and
runtime statistics. See the
-@{$summaries_and_tensorboard#serializing-the-data$Summaries Tutorial}
+[Summaries Tutorial](../guide/summaries_and_tensorboard.md#serializing-the-data)
for details on how to record summaries.
Full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
diff --git a/tensorflow/docs_src/guide/graphs.md b/tensorflow/docs_src/guide/graphs.md
index 2bb44fbb32..c70479dba2 100644
--- a/tensorflow/docs_src/guide/graphs.md
+++ b/tensorflow/docs_src/guide/graphs.md
@@ -38,13 +38,13 @@ programs:
machines. TensorFlow inserts the necessary communication and coordination
between devices.
-* **Compilation.** TensorFlow's @{$performance/xla$XLA compiler} can
+* **Compilation.** TensorFlow's [XLA compiler](../performance/xla/index.md) can
use the information in your dataflow graph to generate faster code, for
example, by fusing together adjacent operations.
* **Portability.** The dataflow graph is a language-independent representation
of the code in your model. You can build a dataflow graph in Python, store it
- in a @{$saved_model$SavedModel}, and restore it in a C++ program for
+ in a [SavedModel](../guide/saved_model.md), and restore it in a C++ program for
low-latency inference.
@@ -93,7 +93,7 @@ to all API functions in the same context. For example:
stored value. The `tf.Variable` object also has methods such as
`tf.Variable.assign` and `tf.Variable.assign_add` that
create `tf.Operation` objects that, when executed, update the stored value.
- (See @{$guide/variables} for more information about variables.)
+ (See [Variables](../guide/variables.md) for more information about variables.)
* Calling `tf.train.Optimizer.minimize` will add operations and tensors to the
default graph that calculates gradients, and return a `tf.Operation` that,
@@ -210,7 +210,7 @@ with tf.device("/device:GPU:0"):
# Operations created in this context will be pinned to the GPU.
result = tf.matmul(weights, img)
```
-If you are deploying TensorFlow in a @{$distributed$typical distributed configuration},
+If you are deploying TensorFlow in a [typical distributed configuration](../deploy/distributed.md),
you might specify the job name and task ID to place variables on
a task in the parameter server job (`"/job:ps"`), and the other operations on
task in the worker job (`"/job:worker"`):
diff --git a/tensorflow/docs_src/guide/index.md b/tensorflow/docs_src/guide/index.md
index 1c920e7d70..50499582cc 100644
--- a/tensorflow/docs_src/guide/index.md
+++ b/tensorflow/docs_src/guide/index.md
@@ -5,38 +5,38 @@ works. The units are as follows:
## High Level APIs
- * @{$guide/keras}, TensorFlow's high-level API for building and
+ * [Keras](../guide/keras.md), TensorFlow's high-level API for building and
training deep learning models.
- * @{$guide/eager}, an API for writing TensorFlow code
+ * [Eager Execution](../guide/eager.md), an API for writing TensorFlow code
imperatively, like you would use Numpy.
- * @{$guide/datasets}, easy input pipelines to bring your data into
+ * [Importing Data](../guide/datasets.md), easy input pipelines to bring your data into
your TensorFlow program.
- * @{$guide/estimators}, a high-level API that provides
+ * [Estimators](../guide/estimators.md), a high-level API that provides
fully-packaged models ready for large-scale training and production.
## Estimators
-* @{$premade_estimators}, the basics of premade Estimators.
-* @{$checkpoints}, save training progress and resume where you left off.
-* @{$feature_columns}, handle a variety of input data types without changes to the model.
-* @{$datasets_for_estimators}, use `tf.data` to input data.
-* @{$custom_estimators}, write your own Estimator.
+* [Premade Estimators](../guide/premade_estimators.md), the basics of premade Estimators.
+* [Checkpoints](../guide/checkpoints.md), save training progress and resume where you left off.
+* [Feature Columns](../guide/feature_columns.md), handle a variety of input data types without changes to the model.
+* [Datasets for Estimators](../guide/datasets_for_estimators.md), use `tf.data` to input data.
+* [Creating Custom Estimators](../guide/custom_estimators.md), write your own Estimator.
## Accelerators
- * @{$using_gpu} explains how TensorFlow assigns operations to
+ * [Using GPUs](../guide/using_gpu.md) explains how TensorFlow assigns operations to
devices and how you can change the arrangement manually.
- * @{$using_tpu} explains how to modify `Estimator` programs to run on a TPU.
+ * [Using TPUs](../guide/using_tpu.md) explains how to modify `Estimator` programs to run on a TPU.
## Low Level APIs
- * @{$guide/low_level_intro}, which introduces the
+ * [Introduction](../guide/low_level_intro.md), which introduces the
basics of how you can use TensorFlow outside of the high Level APIs.
- * @{$guide/tensors}, which explains how to create,
+ * [Tensors](../guide/tensors.md), which explains how to create,
manipulate, and access Tensors--the fundamental object in TensorFlow.
- * @{$guide/variables}, which details how
+ * [Variables](../guide/variables.md), which details how
to represent shared, persistent state in your program.
- * @{$guide/graphs}, which explains:
+ * [Graphs and Sessions](../guide/graphs.md), which explains:
* dataflow graphs, which are TensorFlow's representation of computations
as dependencies between operations.
* sessions, which are TensorFlow's mechanism for running dataflow graphs
@@ -46,19 +46,19 @@ works. The units are as follows:
such as Estimators or Keras, the high-level API creates and manages
graphs and sessions for you, but understanding graphs and sessions
can still be helpful.
- * @{$guide/saved_model}, which
+ * [Save and Restore](../guide/saved_model.md), which
explains how to save and restore variables and models.
## ML Concepts
- * @{$guide/embedding}, which introduces the concept
+ * [Embeddings](../guide/embedding.md), which introduces the concept
of embeddings, provides a simple example of training an embedding in
TensorFlow, and explains how to view embeddings with the TensorBoard
Embedding Projector.
## Debugging
- * @{$guide/debugger}, which
+ * [TensorFlow Debugger](../guide/debugger.md), which
explains how to use the TensorFlow debugger (tfdbg).
## TensorBoard
@@ -66,17 +66,17 @@ works. The units are as follows:
TensorBoard is a utility to visualize different aspects of machine learning.
The following guides explain how to use TensorBoard:
- * @{$guide/summaries_and_tensorboard},
+ * [TensorBoard: Visualizing Learning](../guide/summaries_and_tensorboard.md),
which introduces TensorBoard.
- * @{$guide/graph_viz}, which
+ * [TensorBoard: Graph Visualization](../guide/graph_viz.md), which
explains how to visualize the computational graph.
- * @{$guide/tensorboard_histograms} which demonstrates the how to
+ * [TensorBoard Histogram Dashboard](../guide/tensorboard_histograms.md) which demonstrates the how to
use TensorBoard's histogram dashboard.
## Misc
- * @{$guide/version_compat},
+ * [TensorFlow Version Compatibility](../guide/version_compat.md),
which explains backward compatibility guarantees and non-guarantees.
- * @{$guide/faq}, which contains frequently asked
+ * [Frequently Asked Questions](../guide/faq.md), which contains frequently asked
questions about TensorFlow.
diff --git a/tensorflow/docs_src/guide/low_level_intro.md b/tensorflow/docs_src/guide/low_level_intro.md
index dc6cb9ee0d..d002f8af0b 100644
--- a/tensorflow/docs_src/guide/low_level_intro.md
+++ b/tensorflow/docs_src/guide/low_level_intro.md
@@ -9,7 +9,7 @@ This guide gets you started programming in the low-level TensorFlow APIs
* Use high level components ([datasets](#datasets), [layers](#layers), and
[feature_columns](#feature_columns)) in this low level environment.
* Build your own training loop, instead of using the one
- @{$premade_estimators$provided by Estimators}.
+ [provided by Estimators](../guide/premade_estimators.md).
We recommend using the higher level APIs to build models when possible.
Knowing TensorFlow Core is valuable for the following reasons:
@@ -21,7 +21,7 @@ Knowing TensorFlow Core is valuable for the following reasons:
## Setup
-Before using this guide, @{$install$install TensorFlow}.
+Before using this guide, [install TensorFlow](../install/index.md).
To get the most out of this guide, you should know the following:
@@ -145,7 +145,7 @@ browser, and you should see a graph similar to the following:
![TensorBoard screenshot](https://www.tensorflow.org/images/getting_started_add.png)
-For more about TensorBoard's graph visualization tools see @{$graph_viz}.
+For more about TensorBoard's graph visualization tools see [TensorBoard: Graph Visualization](../guide/graph_viz.md).
### Session
@@ -303,7 +303,7 @@ while True:
break
```
-For more details on Datasets and Iterators see: @{$guide/datasets}.
+For more details on Datasets and Iterators see: [Importing Data](../guide/datasets.md).
## Layers
@@ -398,7 +398,7 @@ and layer reuse impossible.
The easiest way to experiment with feature columns is using the
`tf.feature_column.input_layer` function. This function only accepts
-@{$feature_columns$dense columns} as inputs, so to view the result
+[dense columns](../guide/feature_columns.md) as inputs, so to view the result
of a categorical column you must wrap it in an
`tf.feature_column.indicator_column`. For example:
@@ -589,7 +589,7 @@ print(sess.run(y_pred))
To learn more about building models with TensorFlow consider the following:
-* @{$custom_estimators$Custom Estimators}, to learn how to build
+* [Custom Estimators](../guide/custom_estimators.md), to learn how to build
customized models with TensorFlow. Your knowledge of TensorFlow Core will
help you understand and debug your own models.
@@ -597,8 +597,8 @@ If you want to learn more about the inner workings of TensorFlow consider the
following documents, which go into more depth on many of the topics discussed
here:
-* @{$graphs}
-* @{$tensors}
-* @{$variables}
+* [Graphs and Sessions](../guide/graphs.md)
+* [Tensors](../guide/tensors.md)
+* [Variables](../guide/variables.md)
diff --git a/tensorflow/docs_src/guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md
index dc38f0c1d3..a1703058c3 100644
--- a/tensorflow/docs_src/guide/premade_estimators.md
+++ b/tensorflow/docs_src/guide/premade_estimators.md
@@ -8,7 +8,7 @@ how to solve the Iris classification problem in TensorFlow.
Prior to using the sample code in this document, you'll need to do the
following:
-* @{$install$Install TensorFlow}.
+* [Install TensorFlow](../install/index.md).
* If you installed TensorFlow with virtualenv or Anaconda, activate your
TensorFlow environment.
* Install or upgrade pandas by issuing the following command:
@@ -78,10 +78,10 @@ provides a programming stack consisting of multiple API layers:
We strongly recommend writing TensorFlow programs with the following APIs:
-* @{$guide/estimators$Estimators}, which represent a complete model.
+* [Estimators](../guide/estimators.md), which represent a complete model.
The Estimator API provides methods to train the model, to judge the model's
accuracy, and to generate predictions.
-* @{$guide/datasets_for_estimators}, which build a data input
+* [Datasets for Estimators](../guide/datasets_for_estimators.md), which build a data input
pipeline. The Dataset API has methods to load and manipulate data, and feed
it into your model. The Dataset API meshes well with the Estimators API.
@@ -173,14 +173,14 @@ example is an Iris Versicolor.
An Estimator is TensorFlow's high-level representation of a complete model. It
handles the details of initialization, logging, saving and restoring, and many
other features so you can concentrate on your model. For more details see
-@{$guide/estimators}.
+[Estimators](../guide/estimators.md).
An Estimator is any class derived from `tf.estimator.Estimator`. TensorFlow
provides a collection of
`tf.estimator`
(for example, `LinearRegressor`) to implement common ML algorithms. Beyond
those, you may write your own
-@{$custom_estimators$custom Estimators}.
+[custom Estimators](../guide/custom_estimators.md).
We recommend using pre-made Estimators when just getting started.
To write a TensorFlow program based on pre-made Estimators, you must perform the
@@ -287,7 +287,7 @@ for key in train_x.keys():
```
Feature columns can be far more sophisticated than those we're showing here. We
-detail feature columns @{$feature_columns$later on} in our Getting
+detail feature columns [later on](../guide/feature_columns.md) in our Getting
Started guide.
Now that we have the description of how we want the model to represent the raw
@@ -423,8 +423,8 @@ Pre-made Estimators are an effective way to quickly create standard models.
Now that you've gotten started writing TensorFlow programs, consider the
following material:
-* @{$checkpoints$Checkpoints} to learn how to save and restore models.
-* @{$guide/datasets_for_estimators} to learn more about importing
+* [Checkpoints](../guide/checkpoints.md) to learn how to save and restore models.
+* [Datasets for Estimators](../guide/datasets_for_estimators.md) to learn more about importing
data into your model.
-* @{$custom_estimators$Creating Custom Estimators} to learn how to
+* [Creating Custom Estimators](../guide/custom_estimators.md) to learn how to
write your own Estimator, customized for a particular problem.
diff --git a/tensorflow/docs_src/guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md
index c260da7966..6c967fd882 100644
--- a/tensorflow/docs_src/guide/saved_model.md
+++ b/tensorflow/docs_src/guide/saved_model.md
@@ -7,7 +7,7 @@ automatically save and restore variables in the `model_dir`.
## Save and restore variables
-TensorFlow @{$variables} are the best way to represent shared, persistent state
+TensorFlow [Variables](../guide/variables.md) are the best way to represent shared, persistent state
manipulated by your program. The `tf.train.Saver` constructor adds `save` and
`restore` ops to the graph for all, or a specified list, of the variables in the
graph. The `Saver` object provides methods to run these ops, specifying paths
@@ -274,7 +274,7 @@ Ops has not changed.
The `tf.saved_model.builder.SavedModelBuilder` class allows
users to control whether default-valued attributes must be stripped from the
-@{$extend/tool_developers#nodes$`NodeDefs`}
+[`NodeDefs`](../extend/tool_developers/index.md#nodes)
while adding a meta graph to the SavedModel bundle. Both
`tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables`
and `tf.saved_model.builder.SavedModelBuilder.add_meta_graph`
@@ -413,7 +413,7 @@ SavedModel format. This section explains how to:
### Prepare serving inputs
-During training, an @{$premade_estimators#input_fn$`input_fn()`} ingests data
+During training, an [`input_fn()`](../guide/premade_estimators.md#input_fn) ingests data
and prepares it for use by the model. At serving time, similarly, a
`serving_input_receiver_fn()` accepts inference requests and prepares them for
the model. This function has the following purposes:
@@ -616,7 +616,7 @@ result = stub.Classify(request, 10.0) # 10 secs timeout
The returned result in this example is a `ClassificationResponse` protocol
buffer.
-This is a skeletal example; please see the @{$deploy$Tensorflow Serving}
+This is a skeletal example; please see the [Tensorflow Serving](../deploy/index.md)
documentation and [examples](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/example)
for more details.
@@ -647,7 +647,7 @@ You can use the SavedModel Command Line Interface (CLI) to inspect and
execute a SavedModel.
For example, you can use the CLI to inspect the model's `SignatureDef`s.
The CLI enables you to quickly confirm that the input
-@{$tensors$Tensor dtype and shape} match the model. Moreover, if you
+[Tensor dtype and shape](../guide/tensors.md) match the model. Moreover, if you
want to test your model, you can use the CLI to do a sanity check by
passing in sample inputs in various formats (for example, Python
expressions) and then fetching the output.
diff --git a/tensorflow/docs_src/guide/summaries_and_tensorboard.md b/tensorflow/docs_src/guide/summaries_and_tensorboard.md
index 6177c3393b..788c556b9d 100644
--- a/tensorflow/docs_src/guide/summaries_and_tensorboard.md
+++ b/tensorflow/docs_src/guide/summaries_and_tensorboard.md
@@ -36,7 +36,7 @@ lifecycle for summary data within TensorBoard.
First, create the TensorFlow graph that you'd like to collect summary
data from, and decide which nodes you would like to annotate with
-@{$python/summary$summary operations}.
+[summary operations](../api_guides/python/summary.md).
For example, suppose you are training a convolutional neural network for
recognizing MNIST digits. You'd like to record how the learning rate
@@ -53,7 +53,7 @@ this data by attaching
the gradient outputs and to the variable that holds your weights, respectively.
For details on all of the summary operations available, check out the docs on
-@{$python/summary$summary operations}.
+[summary operations](../api_guides/python/summary.md).
Operations in TensorFlow don't do anything until you run them, or an op that
depends on their output. And the summary nodes that we've just created are
@@ -74,7 +74,7 @@ Also, the `FileWriter` can optionally take a `Graph` in its constructor.
If it receives a `Graph` object, then TensorBoard will visualize your graph
along with tensor shape information. This will give you a much better sense of
what flows through the graph: see
-@{$graph_viz#tensor-shape-information$Tensor shape information}.
+[Tensor shape information](../guide/graph_viz.md#tensor-shape-information).
Now that you've modified your graph and have a `FileWriter`, you're ready to
start running your network! If you want, you could run the merged summary op
@@ -219,7 +219,7 @@ When looking at TensorBoard, you will see the navigation tabs in the top right
corner. Each tab represents a set of serialized data that can be visualized.
For in depth information on how to use the *graph* tab to visualize your graph,
-see @{$graph_viz$TensorBoard: Graph Visualization}.
+see [TensorBoard: Graph Visualization](../guide/graph_viz.md).
For more usage information on TensorBoard in general, see the
[TensorBoard GitHub](https://github.com/tensorflow/tensorboard).
diff --git a/tensorflow/docs_src/guide/tensors.md b/tensorflow/docs_src/guide/tensors.md
index 6b5a110a1c..4f0ddb21b5 100644
--- a/tensorflow/docs_src/guide/tensors.md
+++ b/tensorflow/docs_src/guide/tensors.md
@@ -298,7 +298,7 @@ to call `tf.train.start_queue_runners` before evaluating any `tf.Tensor`s.
## Printing Tensors
For debugging purposes you might want to print the value of a `tf.Tensor`. While
- @{$debugger$tfdbg} provides advanced debugging support, TensorFlow also has an
+ [tfdbg](../guide/debugger.md) provides advanced debugging support, TensorFlow also has an
operation to directly print the value of a `tf.Tensor`.
Note that you rarely want to use the following pattern when printing a
diff --git a/tensorflow/docs_src/guide/using_gpu.md b/tensorflow/docs_src/guide/using_gpu.md
index c0218fd12e..8cb9b354c7 100644
--- a/tensorflow/docs_src/guide/using_gpu.md
+++ b/tensorflow/docs_src/guide/using_gpu.md
@@ -211,5 +211,5 @@ AddN: /job:localhost/replica:0/task:0/cpu:0
[ 98. 128.]]
```
-The @{$deep_cnn$cifar10 tutorial} is a good example
+The [cifar10 tutorial](../tutorials/images/deep_cnn.md) is a good example
demonstrating how to do training with multiple GPUs.
diff --git a/tensorflow/docs_src/guide/using_tpu.md b/tensorflow/docs_src/guide/using_tpu.md
index 90a663b75e..59b34e19e0 100644
--- a/tensorflow/docs_src/guide/using_tpu.md
+++ b/tensorflow/docs_src/guide/using_tpu.md
@@ -22,8 +22,8 @@ Standard `Estimators` can drive models on CPU and GPUs. You must use
`tf.contrib.tpu.TPUEstimator` to drive a model on TPUs.
Refer to TensorFlow's Getting Started section for an introduction to the basics
-of using a @{$premade_estimators$pre-made `Estimator`}, and
-@{$custom_estimators$custom `Estimator`s}.
+of using a [pre-made `Estimator`](../guide/premade_estimators.md), and
+[custom `Estimator`s](../guide/custom_estimators.md).
The `TPUEstimator` class differs somewhat from the `Estimator` class.
@@ -171,9 +171,9 @@ This section details the changes you must make to the model function
During regular usage TensorFlow attempts to determine the shapes of each
`tf.Tensor` during graph construction. During execution any unknown shape
dimensions are determined dynamically,
-see @{$guide/tensors#shape$Tensor Shapes} for more details.
+see [Tensor Shapes](../guide/tensors.md#shape) for more details.
-To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}.
+To run on Cloud TPUs TensorFlow models are compiled using [XLA](../performance/xla/index.md).
XLA uses a similar system for determining shapes at compile time. XLA requires
that all tensor dimensions be statically defined at compile time. All shapes
must evaluate to a constant, and not depend on external data, or stateful
@@ -184,7 +184,7 @@ operations like variables or a random number generator.
Remove any use of `tf.summary` from your model.
-@{$summaries_and_tensorboard$TensorBoard summaries} are a great way see inside
+[TensorBoard summaries](../guide/summaries_and_tensorboard.md) are a great way see inside
your model. A minimal set of basic summaries are automatically recorded by the
`TPUEstimator`, to `event` files in the `model_dir`. Custom summaries, however,
are currently unsupported when training on a Cloud TPU. So while the
@@ -343,7 +343,7 @@ weight when creating your `tf.metrics`.
Efficient use of the `tf.data.Dataset` API is critical when using a Cloud
TPU, as it is impossible to use the Cloud TPU's unless you can feed it data
-quickly enough. See @{$datasets_performance} for details on dataset performance.
+quickly enough. See [Input Pipeline Performance Guide](../performance/datasets_performance.md) for details on dataset performance.
For all but the simplest experimentation (using
`tf.data.Dataset.from_tensor_slices` or other in-graph data) you will need to
@@ -361,7 +361,7 @@ Small datasets can be loaded entirely into memory using
`tf.data.Dataset.cache`.
Regardless of the data format used, it is strongly recommended that you
-@{$performance_guide#use_large_files$use large files}, on the order of
+[use large files](../performance/performance_guide.md#use_large_files), on the order of
100MB. This is especially important in this networked setting as the overhead
of opening a file is significantly higher.
@@ -391,5 +391,5 @@ to make a Cloud TPU compatible model are the example models published in:
For more information about tuning TensorFlow code for performance see:
- * The @{$performance$Performance Section.}
+ * The [Performance Section.](../performance/index.md)
diff --git a/tensorflow/docs_src/guide/version_compat.md b/tensorflow/docs_src/guide/version_compat.md
index 29ac066e6f..de93d225e3 100644
--- a/tensorflow/docs_src/guide/version_compat.md
+++ b/tensorflow/docs_src/guide/version_compat.md
@@ -38,6 +38,9 @@ patch versions. The public APIs consist of
`tensorflow` module and its submodules, except for
* functions and classes in `tf.contrib`
* functions and classes whose names start with `_` (as these are private)
+ * functions, arguments, properties and classes whose name starts with
+ `experimental`, or whose fully qualified name includes a module called
+ `experimental`
Note that the code in the `examples/` and `tools/` directories is not
reachable through the `tensorflow` Python module and is thus not covered by
the compatibility guarantee.
@@ -75,7 +78,7 @@ backward incompatible ways between minor releases. These include:
* **Other languages**: TensorFlow APIs in languages other than Python and C,
such as:
- - @{$cc/guide$C++} (exposed through header files in
+ - [C++](../api_guides/cc/guide.md) (exposed through header files in
[`tensorflow/cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/cc)).
- [Java](../api_docs/java/reference/org/tensorflow/package-summary),
- [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go)
@@ -98,7 +101,7 @@ backward incompatible ways between minor releases. These include:
accuracy for the overall system.
* **Random numbers:** The specific random numbers computed by the
- @{$python/constant_op#Random_Tensors$random ops} may change at any time.
+ [random ops](../api_guides/python/constant_op.md#Random_Tensors) may change at any time.
Users should rely only on approximately correct distributions and
statistical strength, not the specific bits computed. However, we will make
changes to random bits rarely (or perhaps never) for patch releases. We
@@ -175,6 +178,8 @@ This section is relevant only when making incompatible changes to the `GraphDef`
format, such as when adding ops, removing ops, or changing the functionality
of existing ops. The previous section should suffice for most users.
+<a id="backward_forward"/>
+
### Backward and partial forward compatibility
Our versioning scheme has three requirements:
diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md
index 55481cc400..76e590e1e1 100644
--- a/tensorflow/docs_src/install/index.md
+++ b/tensorflow/docs_src/install/index.md
@@ -17,23 +17,23 @@ systems listed above.
The following guides explain how to install a version of TensorFlow
that enables you to write applications in Python:
- * @{$install_linux$Install TensorFlow on Ubuntu}
- * @{$install_mac$Install TensorFlow on macOS}
- * @{$install_windows$Install TensorFlow on Windows}
- * @{$install_raspbian$Install TensorFlow on a Raspberry Pi}
- * @{$install_sources$Install TensorFlow from source code}
+ * [Install TensorFlow on Ubuntu](../install/install_linux.md)
+ * [Install TensorFlow on macOS](../install/install_mac.md)
+ * [Install TensorFlow on Windows](../install/install_windows.md)
+ * [Install TensorFlow on a Raspberry Pi](../install/install_raspbian.md)
+ * [Install TensorFlow from source code](../install/install_sources.md)
Many aspects of the Python TensorFlow API changed from version 0.n to 1.0.
The following guide explains how to migrate older TensorFlow applications
to Version 1.0:
- * @{$migration$Transition to TensorFlow 1.0}
+ * [Transition to TensorFlow 1.0](../install/migration.md)
The following guides explain how to install TensorFlow libraries for use in
other programming languages. These APIs are aimed at deploying TensorFlow
models in applications and are not as extensive as the Python APIs.
- * @{$install_java$Install TensorFlow for Java}
- * @{$install_c$Install TensorFlow for C}
- * @{$install_go$Install TensorFlow for Go}
+ * [Install TensorFlow for Java](../install/install_java.md)
+ * [Install TensorFlow for C](../install/install_c.md)
+ * [Install TensorFlow for Go](../install/install_go.md)
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 4a63f11fca..084634bc9c 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -28,8 +28,8 @@ enable TensorFlow for C:
entitled "Determine which TensorFlow to install" in one of the
following guides:
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
+ * [Installing TensorFlow on Linux](../install/install_linux.md#determine_which_tensorflow_to_install)
+ * [Installing TensorFlow on macOS](../install/install_mac.md#determine_which_tensorflow_to_install)
2. Download and extract the TensorFlow C library into `/usr/local/lib` by
invoking the following shell commands:
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index f0f8436777..0c604d7713 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -29,8 +29,8 @@ steps to install this library and enable TensorFlow for Go:
the help of GPU(s). To help you decide, read the section entitled
"Determine which TensorFlow to install" in one of the following guides:
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
+ * [Installing TensorFlow on Linux](../install/install_linux.md#determine_which_tensorflow_to_install)
+ * [Installing TensorFlow on macOS](../install/install_mac.md#determine_which_tensorflow_to_install)
2. Download and extract the TensorFlow C library into `/usr/local/lib` by
invoking the following shell commands:
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index c131a2ea76..c411cb78fe 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -135,7 +135,7 @@ instead:
GPU acceleration is available via Maven only for Linux and only if your system
meets the
-@{$install_linux#determine_which_tensorflow_to_install$requirements for GPU}.
+[requirements for GPU](../install/install_linux.md#determine_which_tensorflow_to_install).
## Using TensorFlow with JDK
@@ -155,8 +155,8 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
the help of GPU(s). To help you decide, read the section entitled
"Determine which TensorFlow to install" in one of the following guides:
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
+ * [Installing TensorFlow on Linux](../install/install_linux.md#determine_which_tensorflow_to_install)
+ * [Installing TensorFlow on macOS](../install/install_mac.md#determine_which_tensorflow_to_install)
3. Download and extract the appropriate Java Native Interface (JNI)
file for your operating system and processor support by running the
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 0febdee99f..5fcfa4b988 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -520,7 +520,7 @@ The following NVIDIA® <i>software</i> must be installed on your system:
To use a GPU with CUDA Compute Capability 3.0, or different versions of the
preceding NVIDIA libraries see
-@{$install_sources$installing TensorFlow from Sources}. If using Ubuntu 16.04
+[installing TensorFlow from Sources](../install/install_sources.md). If using Ubuntu 16.04
and possibly other Debian based linux distros, `apt-get` can be used with the
NVIDIA repository to simplify installation.
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index dfd9fbce4b..e8e13142e9 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -189,7 +189,7 @@ Note: These are just the minimum requirements to _build_ tensorflow. Installing
the pip package will download additional packages required to _run_ it. If you
plan on executing tasks directly with `bazel` , without the pip installation,
you may need to install additional python packages. For example, you should `pip
-install mock enum34` before running TensorFlow's tests with bazel.
+install enum34` before running TensorFlow's tests with bazel.
<a name="ConfigureInstallation"></a>
@@ -364,6 +364,8 @@ continue to work against your built package.
If RAM is an issue on your system, you may limit RAM usage by specifying
<code>--local_resources 2048,.5,1.0</code> while invoking `bazel`.
+### Run the build_pip_package script
+
The <code>bazel build</code> command builds a script named `build_pip_package`.
Running this script as follows will build a `.whl` file within the
`/tmp/tensorflow_pkg` directory:
diff --git a/tensorflow/docs_src/performance/index.md b/tensorflow/docs_src/performance/index.md
index 131d28fa3e..a0f26a8c3a 100644
--- a/tensorflow/docs_src/performance/index.md
+++ b/tensorflow/docs_src/performance/index.md
@@ -7,18 +7,18 @@ details on the high level APIs to use along with best practices to build
and train high performance models, and quantize models for the least latency
and highest throughput for inference.
- * @{$performance_guide$Performance Guide} contains a collection of best
+ * [Performance Guide](../performance/performance_guide.md) contains a collection of best
practices for optimizing your TensorFlow code.
- * @{$datasets_performance$Data input pipeline guide} describes the tf.data
+ * [Data input pipeline guide](../performance/datasets_performance.md) describes the tf.data
API for building efficient data input pipelines for TensorFlow.
- * @{$performance/benchmarks$Benchmarks} contains a collection of
+ * [Benchmarks](../performance/benchmarks.md) contains a collection of
benchmark results for a variety of hardware configurations.
* For improving inference efficiency on mobile and
embedded hardware, see
- @{$quantization$How to Quantize Neural Networks with TensorFlow}, which
+ [How to Quantize Neural Networks with TensorFlow](../performance/quantization.md), which
explains how to use quantization to reduce model size, both in storage
and at runtime.
@@ -31,20 +31,20 @@ XLA (Accelerated Linear Algebra) is an experimental compiler for linear
algebra that optimizes TensorFlow computations. The following guides explore
XLA:
- * @{$xla$XLA Overview}, which introduces XLA.
- * @{$broadcasting$Broadcasting Semantics}, which describes XLA's
+ * [XLA Overview](../performance/xla/index.md), which introduces XLA.
+ * [Broadcasting Semantics](../performance/xla/broadcasting.md), which describes XLA's
broadcasting semantics.
- * @{$developing_new_backend$Developing a new back end for XLA}, which
+ * [Developing a new back end for XLA](../performance/xla/developing_new_backend.md), which
explains how to re-target TensorFlow in order to optimize the performance
of the computational graph for particular hardware.
- * @{$jit$Using JIT Compilation}, which describes the XLA JIT compiler that
+ * [Using JIT Compilation](../performance/xla/jit.md), which describes the XLA JIT compiler that
compiles and runs parts of TensorFlow graphs via XLA in order to optimize
performance.
- * @{$operation_semantics$Operation Semantics}, which is a reference manual
+ * [Operation Semantics](../performance/xla/operation_semantics.md), which is a reference manual
describing the semantics of operations in the `ComputationBuilder`
interface.
- * @{$shapes$Shapes and Layout}, which details the `Shape` protocol buffer.
- * @{$tfcompile$Using AOT compilation}, which explains `tfcompile`, a
+ * [Shapes and Layout](../performance/xla/shapes.md), which details the `Shape` protocol buffer.
+ * [Using AOT compilation](../performance/xla/tfcompile.md), which explains `tfcompile`, a
standalone tool that compiles TensorFlow graphs into executable code in
order to optimize performance.
diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md
index df70309568..9ea1d6a705 100644
--- a/tensorflow/docs_src/performance/performance_guide.md
+++ b/tensorflow/docs_src/performance/performance_guide.md
@@ -41,7 +41,7 @@ approaches to identifying issues:
utilization is not approaching 80-100%, then the input pipeline may be the
bottleneck.
* Generate a timeline and look for large blocks of white space (waiting). An
- example of generating a timeline exists as part of the @{$jit$XLA JIT}
+ example of generating a timeline exists as part of the [XLA JIT](../performance/xla/jit.md)
tutorial.
* Check CPU usage. It is possible to have an optimized input pipeline and lack
the CPU cycles to process the pipeline.
@@ -68,7 +68,7 @@ the CPU.
#### Using the tf.data API
-The @{$datasets$tf.data API} is replacing `queue_runner` as the recommended API
+The [tf.data API](../guide/datasets.md) is replacing `queue_runner` as the recommended API
for building input pipelines. This
[ResNet example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/cifar10_main.py)
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385))
@@ -78,7 +78,7 @@ training CIFAR-10 illustrates the use of the `tf.data` API along with
The `tf.data` API utilizes C++ multi-threading and has a much lower overhead
than the Python-based `queue_runner` that is limited by Python's multi-threading
performance. A detailed performance guide for the `tf.data` API can be found
-@{$datasets_performance$here}.
+[here](../performance/datasets_performance.md).
While feeding data using a `feed_dict` offers a high level of flexibility, in
general `feed_dict` does not provide a scalable solution. If only a single GPU
@@ -174,7 +174,7 @@ faster using `NHWC` than the normally most efficient `NCHW`.
### Common fused Ops
Fused Ops combine multiple operations into a single kernel for improved
-performance. There are many fused Ops within TensorFlow and @{$xla$XLA} will
+performance. There are many fused Ops within TensorFlow and [XLA](../performance/xla/index.md) will
create fused Ops when possible to automatically improve performance. Collected
below are select fused Ops that can greatly improve performance and may be
overlooked.
@@ -257,7 +257,7 @@ the CPU in use. Speedups for training and inference on CPU are documented below
in [Comparing compiler optimizations](#comparing-compiler-optimizations).
To install the most optimized version of TensorFlow,
-@{$install_sources$build and install} from source. If there is a need to build
+[build and install](../install/install_sources.md) from source. If there is a need to build
TensorFlow on a platform that has different hardware than the target, then
cross-compile with the highest optimizations for the target platform. The
following command is an example of using `bazel` to compile for a specific
@@ -298,7 +298,7 @@ each of the towers. How each tower gets the updated variables and how the
gradients are applied has an impact on the performance, scaling, and convergence
of the model. The rest of this section provides an overview of variable
placement and the towering of a model on multiple GPUs.
-@{$performance_models$High-Performance Models} gets into more details regarding
+[High-Performance Models](../performance/performance_models.md) gets into more details regarding
more complex methods that can be used to share and update variables between
towers.
@@ -307,7 +307,7 @@ and even how the hardware has been configured. An example of this, is that two
systems can be built with NVIDIA Tesla P100s but one may be using PCIe and the
other [NVLink](http://www.nvidia.com/object/nvlink.html). In that scenario, the
optimal solution for each system may be different. For real world examples, read
-the @{$performance/benchmarks$benchmark} page which details the settings that
+the [benchmark](../performance/benchmarks.md) page which details the settings that
were optimal for a variety of platforms. Below is a summary of what was learned
from benchmarking various platforms and configurations:
@@ -433,7 +433,7 @@ scenarios.
## Optimizing for CPU
CPUs, which includes Intel® Xeon Phi™, achieve optimal performance when
-TensorFlow is @{$install_sources$built from source} with all of the instructions
+TensorFlow is [built from source](../install/install_sources.md) with all of the instructions
supported by the target CPU.
Beyond using the latest instruction sets, Intel® has added support for the
diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md
index 66bf684d5b..151c0b2946 100644
--- a/tensorflow/docs_src/performance/performance_models.md
+++ b/tensorflow/docs_src/performance/performance_models.md
@@ -9,7 +9,7 @@ incorporated into high-level APIs.
## Input Pipeline
-The @{$performance_guide$Performance Guide} explains how to identify possible
+The [Performance Guide](../performance/performance_guide.md) explains how to identify possible
input pipeline issues and best practices. We found that using `tf.FIFOQueue`
and `tf.train.queue_runner` could not saturate multiple current generation GPUs
when using large inputs and processing with higher samples per second, such
diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md
index 4499f5715c..3326d82964 100644
--- a/tensorflow/docs_src/performance/quantization.md
+++ b/tensorflow/docs_src/performance/quantization.md
@@ -80,7 +80,7 @@ need for a separate calibration step.
TensorFlow can train models with quantization in the loop. Because training
requires small gradient adjustments, floating point values are still used. To
keep models as floating point while adding the quantization error in the training
-loop, @{$array_ops#Fake_quantization$fake quantization} nodes simulate the
+loop, [fake quantization](../api_guides/python/array_ops.md#Fake_quantization) nodes simulate the
effect of quantization in the forward and backward passes.
Since it's difficult to add these fake quantization operations to all the
diff --git a/tensorflow/docs_src/performance/xla/index.md b/tensorflow/docs_src/performance/xla/index.md
index 8f5de83ea6..770737c34c 100644
--- a/tensorflow/docs_src/performance/xla/index.md
+++ b/tensorflow/docs_src/performance/xla/index.md
@@ -14,7 +14,7 @@ XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
algebra that optimizes TensorFlow computations. The results are improvements in
speed, memory usage, and portability on server and mobile platforms. Initially,
most users will not see large benefits from XLA, but are welcome to experiment
-by using XLA via @{$jit$just-in-time (JIT) compilation} or @{$tfcompile$ahead-of-time (AOT) compilation}. Developers targeting new hardware accelerators are
+by using XLA via [just-in-time (JIT) compilation](../../performance/xla/jit.md) or [ahead-of-time (AOT) compilation](../../performance/xla/tfcompile.md). Developers targeting new hardware accelerators are
especially encouraged to try out XLA.
The XLA framework is experimental and in active development. In particular,
@@ -54,13 +54,13 @@ We had several objectives for XLA to work with TensorFlow:
The input language to XLA is called "HLO IR", or just HLO (High Level
Optimizer). The semantics of HLO are described on the
-@{$operation_semantics$Operation Semantics} page. It
+[Operation Semantics](../../performance/xla/operation_semantics.md) page. It
is most convenient to think of HLO as a [compiler
IR](https://en.wikipedia.org/wiki/Intermediate_representation).
XLA takes graphs ("computations") defined in HLO and compiles them into machine
instructions for various architectures. XLA is modular in the sense that it is
-easy to slot in an alternative backend to @{$developing_new_backend$target some novel HW architecture}. The CPU backend for x64 and ARM64 as
+easy to slot in an alternative backend to [target some novel HW architecture](../../performance/xla/developing_new_backend.md). The CPU backend for x64 and ARM64 as
well as the NVIDIA GPU backend are in the TensorFlow source tree.
The following diagram shows the compilation process in XLA:
@@ -94,5 +94,5 @@ CPU backend supports multiple CPU ISAs.
## Supported Platforms
-XLA currently supports @{$jit$JIT compilation} on x86-64 and NVIDIA GPUs; and
-@{$tfcompile$AOT compilation} for x86-64 and ARM.
+XLA currently supports [JIT compilation](../../performance/xla/jit.md) on x86-64 and NVIDIA GPUs; and
+[AOT compilation](../../performance/xla/tfcompile.md) for x86-64 and ARM.
diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md
index 7202ef47f7..83b3e71566 100644
--- a/tensorflow/docs_src/performance/xla/jit.md
+++ b/tensorflow/docs_src/performance/xla/jit.md
@@ -133,7 +133,7 @@ Execute the python script to train the model with XLA and turn on a debugging
feature of XLA via an environmental variable that outputs the XLA graph.
```shell
-TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
+TF_XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py
```
Open the timeline file created (`timeline.ctf.json`). The rendered timeline
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index e24a7cda73..2de30d1b3d 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -505,16 +505,17 @@ Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as a n-dimensional window moving across a n-dimensional base
area and a computation is performed for each possible position of the window.
-| Arguments | Type | Semantics |
-| ---------------- | ----------------------- | ----------------------------- |
-| `lhs` | `XlaOp` | rank n+2 array of inputs |
-| `rhs` | `XlaOp` | rank n+2 array of kernel |
-: : : weights :
-| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
-| `padding` | `ArraySlice<pair<int64, | n-d array of (low, high) |
-: : int64>>` : padding :
-| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
-| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
+| Arguments | Type | Semantics |
+| --------------------- | -------------------- | ----------------------------- |
+| `lhs` | `XlaOp` | rank n+2 array of inputs |
+| `rhs` | `XlaOp` | rank n+2 array of kernel |
+: : : weights :
+| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
+| `padding` | `ArraySlice< | n-d array of (low, high) |
+: : pair<int64, int64>>` : padding :
+| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
+| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
+| `feature_group_count` | int64 | the number of feature groups |
Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
@@ -532,8 +533,8 @@ The `rhs` argument is a rank n+2 array describing the convolutional
filter/kernel/window. The dimensions are, in this order:
* `output-z`: The `z` dimension of the output.
-* `input-z`: The size of this dimension should equal the size of the `z`
- dimension in lhs.
+* `input-z`: The size of this dimension times `feature_group_count` should
+ equal the size of the `z` dimension in lhs.
* `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
window that moves across the base area.
@@ -566,6 +567,24 @@ Dilation of the rhs is also called atrous convolution. For more details, see
`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
convolution. For more details, see `tf.nn.conv2d_transpose`.
+The `feature_group_count` argument (default value 1) can be used for grouped
+convolutions. `feature_group_count` needs to be a divisor of both the input and
+the output feature dimension. If `feature_group_count` is greater than 1, it
+means that conceptually the input and output feature dimension and the `rhs`
+output feature dimension are split evenly into `feature_group_count` many
+groups, each group consisting of a consecutive subsequence of features. The
+input feature dimension of `rhs` needs to be equal to the `lhs` input feature
+dimension divided by `feature_group_count` (so it already has the size of a
+group of input features). The i-th groups are used together to compute
+`feature_group_count` many separate convolutions. The results of these
+convolutions are concatenated together in the output feature dimension.
+
+For depthwise convolution the `feature_group_count` argument would be set to the
+input feature dimension, and the filter would be reshaped from
+`[filter_height, filter_width, in_channels, channel_multiplier]` to
+`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
+details, see `tf.nn.depthwise_conv2d`.
+
The output shape has these dimensions, in this order:
* `batch`: Same size as `batch` on the input (`lhs`).
@@ -1009,7 +1028,7 @@ Arguments | Type | Semantics
`rhs` | `XlaOp` | right-hand-side operand: array of type T
The arguments' shapes have to be either similar or compatible. See the
-@{$broadcasting$broadcasting} documentation about what it means for shapes to
+[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays. In this variant, operations between arrays of
different ranks are *not* supported, unless one of the operands is a scalar.
@@ -1033,7 +1052,7 @@ the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
shape are filled with dimensions of size one. Degenerate-dimension broadcasting
then broadcasts the shapes along these degenerate dimensions to equalize the
shapes of both operands. The semantics are described in detail on the
-@{$broadcasting$broadcasting page}.
+[broadcasting page](../../performance/xla/broadcasting.md).
## Element-wise comparison operations
@@ -1056,7 +1075,7 @@ Arguments | Type | Semantics
`rhs` | `XlaOp` | right-hand-side operand: array of type T
The arguments' shapes have to be either similar or compatible. See the
-@{$broadcasting$broadcasting} documentation about what it means for shapes to
+[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays with the element type `PRED`. In this variant,
operations between arrays of different ranks are *not* supported, unless one of
@@ -1073,7 +1092,7 @@ matrix to a vector).
The additional `broadcast_dimensions` operand is a slice of integers specifying
the dimensions to use for broadcasting the operands. The semantics are described
-in detail on the @{$broadcasting$broadcasting page}.
+in detail on the [broadcasting page](../../performance/xla/broadcasting.md).
## Element-wise unary functions
@@ -1119,7 +1138,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
## Gather
The XLA gather operation stitches together several slices (each slice at a
-potentially different runtime offset) of an input tensor into an output tensor.
+potentially different runtime offset) of an input array.
### General Semantics
@@ -1127,151 +1146,141 @@ See also
[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
For a more intuitive description, see the "Informal Description" section below.
-<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
+<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
|Arguments | Type | Semantics |
|----------------- | ----------------------- | --------------------------------|
-|`operand` | `XlaOp` | The tensor we’re gathering |
+|`operand` | `XlaOp` | The array we’re gathering |
: : : from. :
-|`gather_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices we're :
-: : : stitching together into the :
-: : : output tensor. :
-|`index_vector_dim` | `int64` | The dimension in |
-: : : `gather_indices` that contains :
-: : : the starting indices. :
-|`output_window_dims` | `ArraySlice<int64>` | The set of dimensions in the |
-: : : output shape that are _window :
-: : : dimensions_ (defined below). :
-: : : Not all window dimensions may :
-: : : be present in the output shape. :
-|`elided_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ |
-: : : that are not present in the output shape. :
-: : : `window_bounds[i]` must be `1` for all `i` :
-: : : in `elided_window_dims`. :
-|`window_bounds` | `ArraySlice<int64>` | `window_bounds[i]` is the bounds |
-: : : for window dimension `i`. This includes :
-: : : both the window dimensions that are :
-: : : explicitly part of the output shape (via :
-: : : `output_window_dims`) and the window :
-: : : dimensions that are elided (via :
-: : : `elided_window_dims`). :
-|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the |
-: : : array is interpreted as mapping `i` to :
-: : : `gather_dims_to_operand_dims[i]`) from :
-: : : the gather indices in `gather_indices` to :
-: : : the operand index space. It has to be :
-: : : one-to-one and total. :
-
-For every index `Out` in the output tensor, we compute two things (more
-precisely described later):
-
- - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`,
- which gives us a starting index of a slice, _operand slice_, in the operand
- tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions
- in `gather_indices` except `index_vector_dim`.
-
- - A _window index_ that has the same rank as the operand. This index is
- composed of the values in `Out` at dimensions `output_window_dims`, embedded
- with zeroes according to `elided_window_dims`.
-
-The _window index_ is the relative index of the element in _operand slice_ that
-should be present in the output at index `Out`.
-
-The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank`
-- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type
-`ArraySlice<int64>` as the set of dimensions in the output shape but not in
-`output_window_dims`, in ascending order. E.g. if the output tensor has rank
-`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`,
-`3`}
-
-If `index_vector_dim` is equal to `gather_indices.rank` we implicitly
-consider `gather_indices` to have a trailing `1` dimension (i.e. if
-`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then
-we implicitly consider the shape of `gather_indices` to be `[6,7,1]`).
-
-The bounds for the output tensor along dimension `i` is computed as follows:
-
- 1. If `i` is present in `output_gather_dims` (i.e. is equal to
- `output_gather_dims[k]` for some `k`) then we pick the corresponding
- dimension bounds out of `gather_indices.shape`, skipping
- `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k`
- < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`]
- otherwise).
- 2. If `i` is present in `output_window_dims` (i.e. equal to
- `output_window_dims`[`k`] for some `k`) then we pick the corresponding
- bound out of `window_bounds` after accounting for `elided_window_dims`
- (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds`
- is `window_bounds` with the bounds at indices `elided_window_dims`
- removed).
-
-The operand index `In` corresponding to an output index `Out` is computed as
-follows:
-
- 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice
- out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)]
- where Combine(A, b) inserts b at position `index_vector_dim` into A.
- Note that this is well defined even if `G` is empty -- if `G` is empty then
- `S` = `gather_indices`.
- 2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by
- scattering `S` using the `gather_dims_to_operand_dims` map
- (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned
- above). More precisely:
- 1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` <
- `gather_dims_to_operand_dims.size`.
+|`start_indices` | `XlaOp` | Array containing the starting |
+: : : indices of the slices we gather.:
+|`index_vector_dim` | `int64` | The dimension in |
+: : : `start_indices` that "contains" :
+: : : the starting indices. See :
+: : : below for a detailed :
+: : : description. :
+|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the :
+: : : output shape that offset into a :
+: : : array sliced from operand. :
+|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bounds |
+: : : for the slice on dimension `i`.:
+|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each :
+| : | slice that are collapsed away. :
+| : | These dimensions must have size:
+| : | 1. |
+|`start_index_map` | `ArraySlice<int64>` | A map that describes how to map|
+: : : indices in `start_indices` to :
+: : : to legal indices into operand. :
+
+For convenience, we label dimensions in the output array not in `offset_dims`
+as `batch_dims`.
+
+The output is an array of rank `batch_dims.size` + `operand.rank` -
+`collapsed_slice_dims`.size.
+
+If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
+`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
+shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
+shape of `start_indices` to be `[6,7,1]`).
+
+The bounds for the output array along dimension `i` is computed as follows:
+
+ 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
+ some `k`) then we pick the corresponding dimension bounds out of
+ `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
+ `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
+ `start_indices.shape.dims`[`k`+`1`] otherwise).
+
+ 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
+ some `k`) then we pick the corresponding bound out of `slice_sizes` after
+ accounting for `collapsed_slice_dims` (i.e. we pick
+ `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
+ with the bounds at indices `collapsed_slice_dims` removed).
+
+Formally, the operand index `In` corresponding to an output index `Out` is
+computed as follows:
+
+ 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
+ vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
+ Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
+ this is well defined even if `G` is empty -- if `G` is empty then `S` =
+ `start_indices`.
+
+ 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
+ scattering `S` using `start_index_map`. More precisely:
+ 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
+ `start_index_map.size`.
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at the output window dimensions in `Out` according to
- the `elided_window_dims` set (`W`<sub>`in`</sub> is the _window index_
- mentioned above). More precisely:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if
- `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is
- defined below).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `In` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+
+ 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
+ at the offset dimensions in `Out` according to the `collapsed_slice_dims`
+ set. More precisely:
+ 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
+ `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
+ (`expand_offset_dims` is defined below).
+ 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
addition.
-`window_dims_to_operand_dims` is the monotonic function with domain [`0`,
-`output_window_dims.size`) and range [`0`, `operand.rank`) \
-`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`,
-`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then
-`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
+`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
+and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
+`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
+`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
### Informal Description and Examples
-`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the
-examples that follow. More interesting values for `index_vector_dim`
-does not change the operation fundamentally, but makes the visual representation
-more cumbersome.
+Informally, every index `Out` in the output array corresponds to an element `E`
+in the operand array, computed as follows:
+
+ - We use the batch dimensions in `Out` to look up a starting index from
+ `start_indices`.
+
+ - We use `start_index_map` to map the starting index (which may have size less
+ than operand.rank) to a "full" starting index into operand.
+
+ - We dynamic-slice out a slice with size `slice_sizes` using the full starting
+ index.
+
+ - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
+ Since all collapsed slice dimensions have to have bound 1 this reshape is
+ always legal.
+
+ - We use the offset dimensions in `Out` to index into this slice to get the
+ input element, `E`, corresponding to output index `Out`.
+
+`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
+examples that follow. More interesting values for `index_vector_dim` does not
+change the operation fundamentally, but makes the visual representation more
+cumbersome.
To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The
-position of a slice into the `[16,11]` tensor can be represented as an index
+example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
+position of a slice into the `[16,11]` array can be represented as an index
vector of shape `S64[2]`, so the set of 5 positions can be represented as a
-`S64[5,2]` tensor.
+`S64[5,2]` array.
The behavior of the gather operation can then be depicted as an index
-transformation that takes [`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>], an index in
-the output shape, and maps it to an element in the input tensor in the following
+transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
+the output shape, and maps it to an element in the input array in the following
way:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/ops_xla_gather_0.svg">
</div>
-We first select an (`X`,`Y`) vector from the gather indices tensor using `G`.
-The element in the output tensor at index
-[`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>] is then the element in the input
-tensor at index [`X`+`W`<sub>`0`</sub>,`Y`+`W`<sub>`1`</sub>].
+We first select an (`X`,`Y`) vector from the gather indices array using `G`.
+The element in the output array at index
+[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
+array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
-`window_bounds` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
+`slice_sizes` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
This gather operation acts as a batch dynamic slice with `G` as the batch
dimension.
The gather indices may be multidimensional. For instance, a more general
-version of the example above using a "gather indices" tensor of shape `[4,5,2]`
+version of the example above using a "gather indices" array of shape `[4,5,2]`
would translate indices like this:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -1279,25 +1288,25 @@ would translate indices like this:
</div>
Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and
-`G`<sub>`1`</sub> as the batch dimensions. The window bounds are still `[8,6]`.
+`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`.
The gather operation in XLA generalizes the informal semantics outlined above in
the following ways:
- 1. We can configure which dimensions in the output shape are the window
- dimensions (dimensions containing `W`<sub>`0`</sub>, `W`<sub>`1`</sub> in
- the last example). The output gather dimensions (dimensions containing
+ 1. We can configure which dimensions in the output shape are the offset
+ dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
+ the last example). The output batch dimensions (dimensions containing
`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
- the output dimensions that are not window dimensions.
+ the output dimensions that are not offset dimensions.
- 2. The number of output window dimensions explicitly present in the output
+ 2. The number of output offset dimensions explicitly present in the output
shape may be smaller than the input rank. These "missing" dimensions, which
- are listed explicitly as `elided_window_dims`, must have a window bound of
- `1`. Since they have a window bound of `1` the only valid index for them is
+ are listed explicitly as `collapsed_slice_dims`, must have a slice size of
+ `1`. Since they have a slice size of `1` the only valid index for them is
`0` and eliding them does not introduce ambiguity.
- 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last
- example) may have fewer elements than the input tensor rank, and an explicit
+ 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
+ example) may have fewer elements than the input array rank, and an explicit
mapping dictates how the index should be expanded to have the same rank as
the input.
@@ -1308,20 +1317,19 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
</div>
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
-from the gather indices tensor as usual, except the starting index has only one
-element, `X`. Similarly, there is only one output window index with the value
-`W`<sub>`0`</sub>. However, before being used as indices into the input tensor,
-these are expanded in accordance to "Gather Index Mapping"
-(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping"
-(`window_dims_to_operand_dims` in the formal description) into
-[`0`,`W`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up to
-[`X`,`W`<sub>`0`</sub>]. In other words, the output index
-[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`W`<sub>`0`</sub>] maps to the input index
+from the gather indices array as usual, except the starting index has only one
+element, `X`. Similarly, there is only one output offset index with the value
+`O`<sub>`0`</sub>. However, before being used as indices into the input array,
+these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
+the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
+description) into [`0`,`O`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up
+to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
+[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
the semantics for `tf.gather_nd`.
-`window_bounds` for this case is `[1,11]`. Intuitively this means that every
-index `X` in the gather indices tensor picks an entire row and the result is the
+`slice_sizes` for this case is `[1,11]`. Intuitively this means that every
+index `X` in the gather indices array picks an entire row and the result is the
concatenation of all these rows.
## GetTupleElement
diff --git a/tensorflow/docs_src/performance/xla/tfcompile.md b/tensorflow/docs_src/performance/xla/tfcompile.md
index e4b803164f..2e0f3774c4 100644
--- a/tensorflow/docs_src/performance/xla/tfcompile.md
+++ b/tensorflow/docs_src/performance/xla/tfcompile.md
@@ -17,7 +17,7 @@ kernels that are actually used in the computation.
The compiler is built on top of the XLA framework. The code bridging TensorFlow
to the XLA framework resides under
[tensorflow/compiler](https://www.tensorflow.org/code/tensorflow/compiler/),
-which also includes support for @{$jit$just-in-time (JIT) compilation} of
+which also includes support for [just-in-time (JIT) compilation](../../performance/xla/jit.md) of
TensorFlow graphs.
## What does tfcompile do?
@@ -116,7 +116,7 @@ tf_library(
> [make_test_graphs.py]("https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/make_test_graphs.py")
> and specify the output location with the --out_dir flag.
-Typical graphs contain @{$python/state_ops$`Variables`}
+Typical graphs contain [`Variables`](../../api_guides/python/state_ops.md)
representing the weights that are learned via training, but `tfcompile` cannot
compile a subgraph that contain `Variables`. The
[freeze_graph.py](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py)
diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml
index 0e25208a00..c0b85497e0 100644
--- a/tensorflow/docs_src/tutorials/_toc.yaml
+++ b/tensorflow/docs_src/tutorials/_toc.yaml
@@ -37,6 +37,26 @@ toc:
status: external
- title: "Custom training: walkthrough"
path: /tutorials/eager/custom_training_walkthrough
+
+- title: ML at production scale
+ style: accordion
+ section:
+ - title: Linear model with Estimators
+ path: /tutorials/estimators/linear
+ - title: Wide and deep learning
+ path: https://github.com/tensorflow/models/tree/master/official/wide_deep
+ status: external
+ - title: Boosted trees
+ path: https://github.com/tensorflow/models/tree/master/official/boosted_trees
+ status: external
+ - title: Text classifier with TF-Hub
+ path: /hub/tutorials/text_classification_with_tf_hub
+ - title: Build a CNN using Estimators
+ path: /tutorials/estimators/cnn
+
+- title: Generative models
+ style: accordion
+ section:
- title: Text generation
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
status: external
@@ -46,41 +66,25 @@ toc:
- title: Image captioning
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
status: external
- - title: Neural Style Transfer
- path: https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb
- status: external
- title: DCGAN
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
status: external
- title: VAE
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
status: external
+
+- title: Images
+ style: accordion
+ section:
- title: Pix2Pix
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
status: external
+ - title: Neural Style Transfer
+ path: https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb
+ status: external
- title: Image Segmentation
path: https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb
status: external
-
-- title: ML at production scale
- style: accordion
- section:
- - title: Linear model with Estimators
- path: /tutorials/estimators/linear
- - title: Wide and deep learning
- path: https://github.com/tensorflow/models/tree/master/official/wide_deep
- status: external
- - title: Boosted trees
- path: https://github.com/tensorflow/models/tree/master/official/boosted_trees
- status: external
- - title: Text classifier with TF-Hub
- path: /hub/tutorials/text_classification_with_tf_hub
- - title: Build a CNN using Estimators
- path: /tutorials/estimators/cnn
-
-- title: Images
- style: accordion
- section:
- title: Image recognition
path: /tutorials/images/image_recognition
- title: Image retraining
diff --git a/tensorflow/docs_src/tutorials/eager/index.md b/tensorflow/docs_src/tutorials/eager/index.md
index a13b396094..887c820b85 100644
--- a/tensorflow/docs_src/tutorials/eager/index.md
+++ b/tensorflow/docs_src/tutorials/eager/index.md
@@ -10,4 +10,3 @@ auto&nbsp;differentiation. Start with these notebooks, then read the
3. <span>[Custom training: basics](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb){:.external}</span>
4. <span>[Custom layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb){:.external}</span>
5. [Custom training: walkthrough](/tutorials/eager/custom_training_walkthrough)
-6. <span>[Advanced example: Neural machine translation with attention](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb){:.external}</span>
diff --git a/tensorflow/docs_src/tutorials/estimators/cnn.md b/tensorflow/docs_src/tutorials/estimators/cnn.md
index 100f501cc2..2fd69f50a0 100644
--- a/tensorflow/docs_src/tutorials/estimators/cnn.md
+++ b/tensorflow/docs_src/tutorials/estimators/cnn.md
@@ -190,7 +190,7 @@ def cnn_model_fn(features, labels, mode):
The following sections (with headings corresponding to each code block above)
dive deeper into the `tf.layers` code used to create each layer, as well as how
to calculate loss, configure the training op, and generate predictions. If
-you're already experienced with CNNs and @{$custom_estimators$TensorFlow `Estimator`s},
+you're already experienced with CNNs and [TensorFlow `Estimator`s](../../guide/custom_estimators.md),
and find the above code intuitive, you may want to skim these sections or just
skip ahead to ["Training and Evaluating the CNN MNIST Classifier"](#train_eval_mnist).
@@ -501,8 +501,8 @@ if mode == tf.estimator.ModeKeys.TRAIN:
```
> Note: For a more in-depth look at configuring training ops for Estimator model
-> functions, see @{$custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"}
-> in the @{$custom_estimators$"Creating Estimations in tf.estimator"} tutorial.
+> functions, see ["Defining the training op for the model"](../../guide/custom_estimators.md#defining-the-training-op-for-the-model)
+> in the ["Creating Estimations in tf.estimator"](../../guide/custom_estimators.md) tutorial.
### Add evaluation metrics
@@ -567,7 +567,7 @@ be saved (here, we specify the temp directory `/tmp/mnist_convnet_model`, but
feel free to change to another directory of your choice).
> Note: For an in-depth walkthrough of the TensorFlow `Estimator` API, see the
-> tutorial @{$custom_estimators$"Creating Estimators in tf.estimator."}
+> tutorial ["Creating Estimators in tf.estimator."](../../guide/custom_estimators.md)
### Set Up a Logging Hook {#set_up_a_logging_hook}
@@ -593,8 +593,8 @@ operation earlier when we generated the probabilities in `cnn_model_fn`.
> Note: If you don't explicitly assign a name to an operation via the `name`
> argument, TensorFlow will assign a default name. A couple easy ways to
> discover the names applied to operations are to visualize your graph on
-> @{$graph_viz$TensorBoard}) or to enable the
-> @{$guide/debugger$TensorFlow Debugger (tfdbg)}.
+> [TensorBoard](../../guide/graph_viz.md)) or to enable the
+> [TensorFlow Debugger (tfdbg)](../../guide/debugger.md).
Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the
`tensors` argument. We set `every_n_iter=50`, which specifies that probabilities
@@ -686,9 +686,9 @@ Here, we've achieved an accuracy of 97.3% on our test data set.
To learn more about TensorFlow Estimators and CNNs in TensorFlow, see the
following resources:
-* @{$custom_estimators$Creating Estimators in tf.estimator}
+* [Creating Estimators in tf.estimator](../../guide/custom_estimators.md)
provides an introduction to the TensorFlow Estimator API. It walks through
configuring an Estimator, writing a model function, calculating loss, and
defining a training op.
-* @{$deep_cnn} walks through how to build a MNIST CNN classification model
+* [Advanced Convolutional Neural Networks](../../tutorials/images/deep_cnn.md) walks through how to build a MNIST CNN classification model
*without estimators* using lower-level TensorFlow operations.
diff --git a/tensorflow/docs_src/tutorials/images/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md
index 42ad484bbf..00996b82e6 100644
--- a/tensorflow/docs_src/tutorials/images/deep_cnn.md
+++ b/tensorflow/docs_src/tutorials/images/deep_cnn.md
@@ -40,7 +40,7 @@ designing larger and more sophisticated models in TensorFlow:
and `tf.nn.local_response_normalization`
(Chapter 3.3 in
[AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)).
-* @{$summaries_and_tensorboard$Visualization}
+* [Visualization](../../guide/summaries_and_tensorboard.md)
of network activities during training, including input images,
losses and distributions of activations and gradients.
* Routines for calculating the
@@ -114,7 +114,7 @@ The input part of the model is built by the functions `inputs()` and
`distorted_inputs()` which read images from the CIFAR-10 binary data files.
These files contain fixed byte length records, so we use
`tf.FixedLengthRecordReader`.
-See @{$reading_data#reading-from-files$Reading Data} to
+See [Reading Data](../../api_guides/python/reading_data.md#reading-from-files) to
learn more about how the `Reader` class works.
The images are processed as follows:
@@ -131,10 +131,10 @@ artificially increase the data set size:
* Randomly distort the `tf.image.random_brightness`.
* Randomly distort the `tf.image.random_contrast`.
-Please see the @{$python/image$Images} page for the list of
+Please see the [Images](../../api_guides/python/image.md) page for the list of
available distortions. We also attach an
`tf.summary.image` to the images
-so that we may visualize them in @{$summaries_and_tensorboard$TensorBoard}.
+so that we may visualize them in [TensorBoard](../../guide/summaries_and_tensorboard.md).
This is a good practice to verify that inputs are built correctly.
<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
@@ -160,8 +160,8 @@ Layer Name | Description
`conv2` | `tf.nn.conv2d` and `tf.nn.relu` activation.
`norm2` | `tf.nn.local_response_normalization`.
`pool2` | `tf.nn.max_pool`.
-`local3` | @{$python/nn$fully connected layer with rectified linear activation}.
-`local4` | @{$python/nn$fully connected layer with rectified linear activation}.
+`local3` | [fully connected layer with rectified linear activation](../../api_guides/python/nn.md).
+`local4` | [fully connected layer with rectified linear activation](../../api_guides/python/nn.md).
`softmax_linear` | linear transformation to produce logits.
Here is a graph generated from TensorBoard describing the inference operation:
@@ -205,7 +205,7 @@ We visualize it in TensorBoard with a `tf.summary.scalar`:
We train the model using standard
[gradient descent](https://en.wikipedia.org/wiki/Gradient_descent)
-algorithm (see @{$python/train$Training} for other methods)
+algorithm (see [Training](../../api_guides/python/train.md) for other methods)
with a learning rate that
`tf.train.exponential_decay`
over time.
@@ -265,7 +265,7 @@ in `cifar10_input.py`.
`cifar10_train.py` periodically uses a `tf.train.Saver` to save
all model parameters in
-@{$guide/saved_model$checkpoint files}
+[checkpoint files](../../guide/saved_model.md)
but it does *not* evaluate the model. The checkpoint file
will be used by `cifar10_eval.py` to measure the predictive
performance (see [Evaluating a Model](#evaluating-a-model) below).
@@ -282,7 +282,7 @@ how the model is training. We want more insight into the model during training:
* Are the gradients, activations and weights reasonable?
* What is the learning rate currently at?
-@{$summaries_and_tensorboard$TensorBoard} provides this
+[TensorBoard](../../guide/summaries_and_tensorboard.md) provides this
functionality, displaying data exported periodically from `cifar10_train.py` via
a
`tf.summary.FileWriter`.
@@ -413,7 +413,7 @@ scope indicating that they should be run on the first GPU.
All variables are pinned to the CPU and accessed via
`tf.get_variable`
in order to share them in a multi-GPU version.
-See how-to on @{$variables$Sharing Variables}.
+See how-to on [Sharing Variables](../../guide/variables.md).
### Launching and Training the Model on Multiple GPU cards
diff --git a/tensorflow/docs_src/tutorials/images/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md
index 83a8d97cf0..52913b2082 100644
--- a/tensorflow/docs_src/tutorials/images/image_recognition.md
+++ b/tensorflow/docs_src/tutorials/images/image_recognition.md
@@ -106,7 +106,7 @@ curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception
Next, we need to compile the C++ binary that includes the code to load and run the graph.
If you've followed
-@{$install_sources$the instructions to download the source installation of TensorFlow}
+[the instructions to download the source installation of TensorFlow](../../install/install_sources.md)
for your platform, you should be able to build the example by
running this command from your shell terminal:
@@ -448,7 +448,7 @@ and Michael Nielsen's book has a
covering them.
To find out more about implementing convolutional neural networks, you can jump
-to the TensorFlow @{$deep_cnn$deep convolutional networks tutorial},
+to the TensorFlow [deep convolutional networks tutorial](../../tutorials/images/deep_cnn.md),
or start a bit more gently with our [Estimator MNIST tutorial](../estimators/cnn.md).
Finally, if you want to get up to speed on research in this area, you can
read the recent work of all the papers referenced in this tutorial.
diff --git a/tensorflow/docs_src/tutorials/representation/kernel_methods.md b/tensorflow/docs_src/tutorials/representation/kernel_methods.md
index 71e87f4d3e..67adc4951c 100644
--- a/tensorflow/docs_src/tutorials/representation/kernel_methods.md
+++ b/tensorflow/docs_src/tutorials/representation/kernel_methods.md
@@ -2,7 +2,7 @@
Note: This document uses a deprecated version of `tf.estimator`,
`tf.contrib.learn.Estimator`, which has a different interface. It also uses
-other `contrib` methods whose @{$version_compat#not_covered$API may not be stable}.
+other `contrib` methods whose [API may not be stable](../../guide/version_compat.md#not_covered).
In this tutorial, we demonstrate how combining (explicit) kernel methods with
linear models can drastically increase the latters' quality of predictions
@@ -52,7 +52,7 @@ In order to feed data to a `tf.contrib.learn Estimator`, it is helpful to conver
it to Tensors. For this, we will use an `input function` which adds Ops to the
TensorFlow graph that, when executed, create mini-batches of Tensors to be used
downstream. For more background on input functions, check
-@{$premade_estimators#create_input_functions$this section on input functions}.
+[this section on input functions](../../guide/premade_estimators.md#create_input_functions).
In this example, we will use the `tf.train.shuffle_batch` Op which, besides
converting numpy arrays to Tensors, allows us to specify the batch_size and
whether to randomize the input every time the input_fn Ops are executed
diff --git a/tensorflow/docs_src/tutorials/representation/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md
index 014409c617..4f0e67f08e 100644
--- a/tensorflow/docs_src/tutorials/representation/linear.md
+++ b/tensorflow/docs_src/tutorials/representation/linear.md
@@ -18,7 +18,7 @@ tutorial walks through the code in greater detail.
To understand this overview it will help to have some familiarity
with basic machine learning concepts, and also with
-@{$premade_estimators$Estimators}.
+[Estimators](../../guide/premade_estimators.md).
[TOC]
@@ -175,7 +175,7 @@ the data itself. You provide the data through an input function.
The input function must return a dictionary of tensors. Each key corresponds to
the name of a `FeatureColumn`. Each key's value is a tensor containing the
values of that feature for all data instances. See
-@{$premade_estimators#input_fn} for a
+[Premade Estimators](../../guide/premade_estimators.md#input_fn) for a
more comprehensive look at input functions, and `input_fn` in the
[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
for an example implementation of an input function.
diff --git a/tensorflow/docs_src/tutorials/representation/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md
index 7964650e19..df0d3176b6 100644
--- a/tensorflow/docs_src/tutorials/representation/word2vec.md
+++ b/tensorflow/docs_src/tutorials/representation/word2vec.md
@@ -383,13 +383,13 @@ compromised speed because we use Python for reading and feeding data items --
each of which require very little work on the TensorFlow back-end. If you find
your model is seriously bottlenecked on input data, you may want to implement a
custom data reader for your problem, as described in
-@{$new_data_formats$New Data Formats}. For the case of Skip-Gram
+[New Data Formats](../../extend/new_data_formats.md). For the case of Skip-Gram
modeling, we've actually already done this for you as an example in
[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
If your model is no longer I/O bound but you want still more performance, you
can take things further by writing your own TensorFlow Ops, as described in
-@{$adding_an_op$Adding a New Op}. Again we've provided an
+[Adding a New Op](../../extend/adding_an_op.md). Again we've provided an
example of this for the Skip-Gram case
[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py).
Feel free to benchmark these against each other to measure performance
diff --git a/tensorflow/docs_src/tutorials/sequences/recurrent.md b/tensorflow/docs_src/tutorials/sequences/recurrent.md
index 10d60f7966..39ad441381 100644
--- a/tensorflow/docs_src/tutorials/sequences/recurrent.md
+++ b/tensorflow/docs_src/tutorials/sequences/recurrent.md
@@ -138,7 +138,7 @@ for current_batch_of_words in words_in_dataset:
### Inputs
The word IDs will be embedded into a dense representation (see the
-@{$word2vec$Vector Representations Tutorial}) before feeding to
+[Vector Representations Tutorial](../../tutorials/representation/word2vec.md)) before feeding to
the LSTM. This allows the model to efficiently represent the knowledge about
particular words. It is also easy to write:
diff --git a/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md
index 37bce5b76d..657fab8a53 100644
--- a/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md
+++ b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md
@@ -32,7 +32,7 @@ drawings in 345 categories.
To try the code for this tutorial:
-1. @{$install$Install TensorFlow} if you haven't already.
+1. [Install TensorFlow](../../install/index.md) if you haven't already.
1. Download the [tutorial code]
(https://github.com/tensorflow/models/tree/master/tutorials/rnn/quickdraw/train_model.py).
1. [Download the data](#download-the-data) in `TFRecord` format from
@@ -58,8 +58,7 @@ To try the code for this tutorial:
We make the data that we use in this tutorial available as `TFRecord` files
containing `TFExamples`. You can download the data from here:
-
-http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz
+<a rel="nofollow" href="http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz">http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz</a> (~1GB).
Alternatively you can download the original data in `ndjson` format from the
Google cloud and convert it to the `TFRecord` files containing `TFExamples`
@@ -108,7 +107,7 @@ This download will take a while and download a bit more than 23GB of data.
### Optional: Converting the data
To convert the `ndjson` files to
-@{$python/python_io#TFRecords_Format_Details$TFRecord} files containing
+[TFRecord](../../api_guides/python/python_io.md#TFRecords_Format_Details) files containing
[`tf.train.Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
protos run the following command.
@@ -118,7 +117,7 @@ protos run the following command.
```
This will store the data in 10 shards of
-@{$python/python_io#TFRecords_Format_Details$TFRecord} files with 10000 items
+[TFRecord](../../api_guides/python/python_io.md#TFRecords_Format_Details) files with 10000 items
per class for the training data and 1000 items per class as eval data.
This conversion process is described in more detail in the following.
@@ -220,7 +219,7 @@ length 2.
### Defining the model
To define the model we create a new `Estimator`. If you want to read more about
-estimators, we recommend @{$custom_estimators$this tutorial}.
+estimators, we recommend [this tutorial](../../guide/custom_estimators.md).
To build the model, we:
diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.h b/tensorflow/examples/ios/benchmark/ios_image_load.h
index 78eaded8d7..3f94984692 100644
--- a/tensorflow/examples/ios/benchmark/ios_image_load.h
+++ b/tensorflow/examples/ios/benchmark/ios_image_load.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
-#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
#include <vector>
@@ -24,4 +24,4 @@ std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);
-#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#endif // TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/examples/ios/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h
index 87a847e145..f10b0b983a 100644
--- a/tensorflow/examples/ios/camera/ios_image_load.h
+++ b/tensorflow/examples/ios/camera/ios_image_load.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
-#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
#include <vector>
@@ -24,4 +24,4 @@ std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);
-#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
+#endif // TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 3e0ea619e3..de096acc4d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3355,6 +3355,28 @@ func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
+//
+// For each entry in `x`, calculates the number of `1` (on) bits in the binary
+// representation of that entry.
+//
+// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into
+// `int32` or `int64` and perform the bitcount on the result, than to feed in
+// 8- or 16-bit inputs and then aggregate the resulting counts.
+func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "PopulationCount",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the mean along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -4037,78 +4059,6 @@ func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
return op.Output(0)
}
-// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
-type FusedBatchNormAttr func(optionalAttr)
-
-// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
-//
-// value: A small float number added to the variance of x.
-// If not specified, defaults to 0.0001
-func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["epsilon"] = value
- }
-}
-
-// FusedBatchNormDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
-// If not specified, defaults to "NHWC"
-func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// FusedBatchNormIsTraining sets the optional is_training attribute to value.
-//
-// value: A bool value to indicate the operation is for training (default)
-// or inference.
-// If not specified, defaults to true
-func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["is_training"] = value
- }
-}
-
-// Batch normalization.
-//
-// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-// The size of 1D Tensors matches the dimension C of the 4D Tensors.
-//
-// Arguments:
-// x: A 4D Tensor for input data.
-// scale: A 1D Tensor for scaling factor, to scale the normalized x.
-// offset: A 1D Tensor for offset, to shift to the normalized x.
-// mean: A 1D Tensor for population mean. Used for inference only;
-// must be empty for training.
-// variance: A 1D Tensor for population variance. Used for inference only;
-// must be empty for training.
-//
-// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
-// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
-// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
-// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
-// in the cuDNN case), to be reused in the gradient computation.
-func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FusedBatchNorm",
- Input: []tf.Input{
- x, scale, offset, mean, variance,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -8419,139 +8369,6 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
return op.Output(0)
}
-// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
-type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
-
-// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, height, width, channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, channels, height, width].
-// If not specified, defaults to "NHWC"
-func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value.
-//
-// value: 1-D tensor of length 4. The dilation factor for each dimension of
-// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
-// element on that dimension. The dimension order is determined by the value of
-// `data_format`, see above for details. Dilations in the batch and depth
-// dimensions must be 1.
-// If not specified, defaults to <i:1 i:1 i:1 i:1 >
-func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
- return func(m optionalAttr) {
- m["dilations"] = value
- }
-}
-
-// Computes the gradients of depthwise convolution with respect to the filter.
-//
-// Arguments:
-// input: 4-D with shape based on `data_format`. For example, if
-// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
-// in_width, in_channels]` tensor.
-// filter_sizes: An integer vector representing the tensor shape of `filter`,
-// where `filter` is a 4-D
-// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor.
-// out_backprop: 4-D with shape based on `data_format`.
-// For example, if `data_format` is 'NHWC' then
-// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
-// Gradients w.r.t. the output of the convolution.
-// strides: The stride of the sliding window for each dimension of the input
-// of the convolution.
-// padding: The type of padding algorithm to use.
-//
-// Returns 4-D with shape
-// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
-// the `filter` input of the convolution.
-func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DepthwiseConv2dNativeBackpropFilter",
- Input: []tf.Input{
- input, filter_sizes, out_backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns immutable tensor from memory region.
-//
-// The current implementation memmaps the tensor from a file.
-//
-// Arguments:
-// dtype: Type of the returned tensor.
-// shape: Shape of the returned tensor.
-// memory_region_name: Name of readonly memory region used by the tensor, see
-// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
-func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
- opspec := tf.OpSpec{
- Type: "ImmutableConst",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// StringJoinAttr is an optional argument to StringJoin.
-type StringJoinAttr func(optionalAttr)
-
-// StringJoinSeparator sets the optional separator attribute to value.
-//
-// value: string, an optional join separator.
-// If not specified, defaults to ""
-func StringJoinSeparator(value string) StringJoinAttr {
- return func(m optionalAttr) {
- m["separator"] = value
- }
-}
-
-// Joins the strings in the given list of string tensors into one tensor;
-//
-// with the given separator (default is an empty separator).
-//
-// Arguments:
-// inputs: A list of string tensors. The tensors must all have the same shape,
-// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
-// of non-scalar inputs.
-func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringJoin",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -8794,28 +8611,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
-//
-// For each entry in `x`, calculates the number of `1` (on) bits in the binary
-// representation of that entry.
-//
-// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into
-// `int32` or `int64` and perform the bitcount on the result, than to feed in
-// 8- or 16-bit inputs and then aggregate the resulting counts.
-func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "PopulationCount",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Broadcasts a tensor value to one or more other devices.
func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
if scope.Err() != nil {
@@ -9496,34 +9291,216 @@ func IsInf(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
+// TruncatedNormalAttr is an optional argument to TruncatedNormal.
+type TruncatedNormalAttr func(optionalAttr)
+
+// TruncatedNormalSeed sets the optional seed attribute to value.
//
-// N is the size of the segment being reduced.
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func TruncatedNormalSeed(value int64) TruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// TruncatedNormalSeed2 sets the optional seed2 attribute to value.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func TruncatedNormalSeed2(value int64) TruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random values from a truncated normal distribution.
+//
+// The generated values follow a normal distribution with mean 0 and standard
+// deviation 1, except that values whose magnitude is more than 2 standard
+// deviations from the mean are dropped and re-picked.
//
// Arguments:
+// shape: The shape of the output tensor.
+// dtype: The type of the output.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// Returns A tensor of the specified shape filled with random truncated normal
+// values.
+func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TruncatedNormal",
+ Input: []tf.Input{
+ shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// SkipgramAttr is an optional argument to Skipgram.
+type SkipgramAttr func(optionalAttr)
+
+// SkipgramWindowSize sets the optional window_size attribute to value.
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// value: The number of words to predict to the left and right of the target.
+// If not specified, defaults to 5
+func SkipgramWindowSize(value int64) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["window_size"] = value
+ }
+}
+
+// SkipgramMinCount sets the optional min_count attribute to value.
+//
+// value: The minimum number of word occurrences for it to be included in the
+// vocabulary.
+// If not specified, defaults to 5
+func SkipgramMinCount(value int64) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["min_count"] = value
+ }
+}
+
+// SkipgramSubsample sets the optional subsample attribute to value.
+//
+// value: Threshold for word occurrence. Words that appear with higher
+// frequency will be randomly down-sampled. Set to 0 to disable.
+// If not specified, defaults to 0.001
+func SkipgramSubsample(value float32) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["subsample"] = value
+ }
+}
+
+// Parses a text file and creates a batch of examples.
+//
+// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result
+//
+// Arguments:
+// filename: The corpus's text file name.
+// batch_size: The size of produced batch.
+//
+// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids.
+func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentSqrtN",
+ Type: "Skipgram",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6)
+}
+
+// StringToNumberAttr is an optional argument to StringToNumber.
+type StringToNumberAttr func(optionalAttr)
+
+// StringToNumberOutType sets the optional out_type attribute to value.
+//
+// value: The numeric type to interpret each string in `string_tensor` as.
+// If not specified, defaults to DT_FLOAT
+func StringToNumberOutType(value tf.DataType) StringToNumberAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Converts each string in the input Tensor to the specified numeric type.
+//
+// (Note that int32 overflow results in an error while float overflow
+// results in a rounded value.)
+//
+// Returns A Tensor of the same shape as the input `string_tensor`.
+func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringToNumber",
Input: []tf.Input{
- data, indices, segment_ids,
+ string_tensor,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
+// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2.
+type ResourceApplyFtrlV2Attr func(optionalAttr)
+
+// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the Ftrl-proximal scheme.
+//
+// grad_with_shrinkage = grad + 2 * l2_shrinkage * var
+// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
+// linear += grad_with_shrinkage +
+// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
+// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+// accum = accum_new
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// linear: Should be from a Variable().
+// grad: The gradient.
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regulariation. Must be a scalar.
+// l2: L2 shrinkage regulariation. Must be a scalar.
+//
+// lr_power: Scaling factor. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyFtrlV2",
+ Input: []tf.Input{
+ var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`.
//
// This Op does not require `a_indices` be sorted in standard lexicographic order.
@@ -9824,6 +9801,139 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option
return op.Output(0)
}
+// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
+type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
+
+// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, height, width, channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, channels, height, width].
+// If not specified, defaults to "NHWC"
+func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value.
+//
+// value: 1-D tensor of length 4. The dilation factor for each dimension of
+// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+// element on that dimension. The dimension order is determined by the value of
+// `data_format`, see above for details. Dilations in the batch and depth
+// dimensions must be 1.
+// If not specified, defaults to <i:1 i:1 i:1 i:1 >
+func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
+ return func(m optionalAttr) {
+ m["dilations"] = value
+ }
+}
+
+// Computes the gradients of depthwise convolution with respect to the filter.
+//
+// Arguments:
+// input: 4-D with shape based on `data_format`. For example, if
+// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
+// in_width, in_channels]` tensor.
+// filter_sizes: An integer vector representing the tensor shape of `filter`,
+// where `filter` is a 4-D
+// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor.
+// out_backprop: 4-D with shape based on `data_format`.
+// For example, if `data_format` is 'NHWC' then
+// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
+// Gradients w.r.t. the output of the convolution.
+// strides: The stride of the sliding window for each dimension of the input
+// of the convolution.
+// padding: The type of padding algorithm to use.
+//
+// Returns 4-D with shape
+// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
+// the `filter` input of the convolution.
+func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DepthwiseConv2dNativeBackpropFilter",
+ Input: []tf.Input{
+ input, filter_sizes, out_backprop,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns immutable tensor from memory region.
+//
+// The current implementation memmaps the tensor from a file.
+//
+// Arguments:
+// dtype: Type of the returned tensor.
+// shape: Shape of the returned tensor.
+// memory_region_name: Name of readonly memory region used by the tensor, see
+// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
+func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
+ opspec := tf.OpSpec{
+ Type: "ImmutableConst",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// StringJoinAttr is an optional argument to StringJoin.
+type StringJoinAttr func(optionalAttr)
+
+// StringJoinSeparator sets the optional separator attribute to value.
+//
+// value: string, an optional join separator.
+// If not specified, defaults to ""
+func StringJoinSeparator(value string) StringJoinAttr {
+ return func(m optionalAttr) {
+ m["separator"] = value
+ }
+}
+
+// Joins the strings in the given list of string tensors into one tensor;
+//
+// with the given separator (default is an empty separator).
+//
+// Arguments:
+// inputs: A list of string tensors. The tensors must all have the same shape,
+// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
+// of non-scalar inputs.
+func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringJoin",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StringSplitV2Attr is an optional argument to StringSplitV2.
type StringSplitV2Attr func(optionalAttr)
@@ -9997,6 +10107,24 @@ func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatM
return op.Output(0)
}
+// Elementwise computes the bitwise AND of `x` and `y`.
+//
+// The result will have those bits set, that are set in both `x` and `y`. The
+// computation is performed on the underlying representations of `x` and `y`.
+func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BitwiseAnd",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Concatenates quantized tensors along one dimension.
//
// Arguments:
@@ -11227,6 +11355,85 @@ func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max
return op.Output(0)
}
+// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
+type ResourceScatterNdUpdateAttr func(optionalAttr)
+
+// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value.
+//
+// value: An optional bool. Defaults to True. If True, the assignment will
+// be protected by a lock; otherwise the behavior is undefined,
+// but may exhibit less contention.
+// If not specified, defaults to true
+func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Applies sparse `updates` to individual values or slices within a given
+//
+// variable according to `indices`.
+//
+// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+//
+// `indices` must be integer tensor, containing indices into `ref`.
+// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+//
+// The innermost dimension of `indices` (with length `K`) corresponds to
+// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+// dimension of `ref`.
+//
+// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+//
+// ```
+// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+// ```
+//
+// For example, say we want to update 4 scattered elements to a rank-1 tensor to
+// 8 elements. In Python, that update would look like this:
+//
+// ```python
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+// indices = tf.constant([[4], [3], [1] ,[7]])
+// updates = tf.constant([9, 10, 11, 12])
+// update = tf.scatter_nd_update(ref, indices, updates)
+// with tf.Session() as sess:
+// print sess.run(update)
+// ```
+//
+// The resulting update to ref would look like this:
+//
+// [1, 11, 3, 10, 9, 6, 7, 12]
+//
+// See @{tf.scatter_nd} for more details about how to make updates to
+// slices.
+//
+// Arguments:
+// ref: A resource handle. Must be from a VarHandleOp.
+// indices: A Tensor. Must be one of the following types: int32, int64.
+// A tensor of indices into ref.
+// updates: A Tensor. Must have the same type as ref. A tensor of updated
+// values to add to ref.
+//
+// Returns the created operation.
+func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceScatterNdUpdate",
+ Input: []tf.Input{
+ ref, indices, updates,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -12171,34 +12378,6 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.
return values
}
-// Inverse fast Fourier transform.
-//
-// Computes the inverse 1-dimensional discrete Fourier transform over the
-// inner-most dimension of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most
-// dimension of `input` is replaced with its inverse 1D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.ifft
-// @end_compatibility
-func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IFFT",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
type ResourceSparseApplyRMSPropAttr func(optionalAttr)
@@ -12777,85 +12956,6 @@ func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataT
return op.Output(0), op.Output(1), op.Output(2)
}
-// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
-type ResourceScatterNdUpdateAttr func(optionalAttr)
-
-// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value.
-//
-// value: An optional bool. Defaults to True. If True, the assignment will
-// be protected by a lock; otherwise the behavior is undefined,
-// but may exhibit less contention.
-// If not specified, defaults to true
-func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Applies sparse `updates` to individual values or slices within a given
-//
-// variable according to `indices`.
-//
-// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
-//
-// `indices` must be integer tensor, containing indices into `ref`.
-// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
-//
-// The innermost dimension of `indices` (with length `K`) corresponds to
-// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
-// dimension of `ref`.
-//
-// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
-//
-// ```
-// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-// ```
-//
-// For example, say we want to update 4 scattered elements to a rank-1 tensor to
-// 8 elements. In Python, that update would look like this:
-//
-// ```python
-// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
-// indices = tf.constant([[4], [3], [1] ,[7]])
-// updates = tf.constant([9, 10, 11, 12])
-// update = tf.scatter_nd_update(ref, indices, updates)
-// with tf.Session() as sess:
-// print sess.run(update)
-// ```
-//
-// The resulting update to ref would look like this:
-//
-// [1, 11, 3, 10, 9, 6, 7, 12]
-//
-// See @{tf.scatter_nd} for more details about how to make updates to
-// slices.
-//
-// Arguments:
-// ref: A resource handle. Must be from a VarHandleOp.
-// indices: A Tensor. Must be one of the following types: int32, int64.
-// A tensor of indices into ref.
-// updates: A Tensor. Must have the same type as ref. A tensor of updated
-// values to add to ref.
-//
-// Returns the created operation.
-func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceScatterNdUpdate",
- Input: []tf.Input{
- ref, indices, updates,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// SqueezeAttr is an optional argument to Squeeze.
type SqueezeAttr func(optionalAttr)
@@ -16074,6 +16174,78 @@ func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
+type FusedBatchNormAttr func(optionalAttr)
+
+// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
+//
+// value: A small float number added to the variance of x.
+// If not specified, defaults to 0.0001
+func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["epsilon"] = value
+ }
+}
+
+// FusedBatchNormDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
+// If not specified, defaults to "NHWC"
+func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// FusedBatchNormIsTraining sets the optional is_training attribute to value.
+//
+// value: A bool value to indicate the operation is for training (default)
+// or inference.
+// If not specified, defaults to true
+func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["is_training"] = value
+ }
+}
+
+// Batch normalization.
+//
+// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+// The size of 1D Tensors matches the dimension C of the 4D Tensors.
+//
+// Arguments:
+// x: A 4D Tensor for input data.
+// scale: A 1D Tensor for scaling factor, to scale the normalized x.
+// offset: A 1D Tensor for offset, to shift to the normalized x.
+// mean: A 1D Tensor for population mean. Used for inference only;
+// must be empty for training.
+// variance: A 1D Tensor for population variance. Used for inference only;
+// must be empty for training.
+//
+// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
+// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
+// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
+// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
+// in the cuDNN case), to be reused in the gradient computation.
+func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FusedBatchNorm",
+ Input: []tf.Input{
+ x, scale, offset, mean, variance,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
type RandomStandardNormalAttr func(optionalAttr)
@@ -16882,216 +17054,6 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output
return op.Output(0)
}
-// SkipgramAttr is an optional argument to Skipgram.
-type SkipgramAttr func(optionalAttr)
-
-// SkipgramWindowSize sets the optional window_size attribute to value.
-//
-// value: The number of words to predict to the left and right of the target.
-// If not specified, defaults to 5
-func SkipgramWindowSize(value int64) SkipgramAttr {
- return func(m optionalAttr) {
- m["window_size"] = value
- }
-}
-
-// SkipgramMinCount sets the optional min_count attribute to value.
-//
-// value: The minimum number of word occurrences for it to be included in the
-// vocabulary.
-// If not specified, defaults to 5
-func SkipgramMinCount(value int64) SkipgramAttr {
- return func(m optionalAttr) {
- m["min_count"] = value
- }
-}
-
-// SkipgramSubsample sets the optional subsample attribute to value.
-//
-// value: Threshold for word occurrence. Words that appear with higher
-// frequency will be randomly down-sampled. Set to 0 to disable.
-// If not specified, defaults to 0.001
-func SkipgramSubsample(value float32) SkipgramAttr {
- return func(m optionalAttr) {
- m["subsample"] = value
- }
-}
-
-// Parses a text file and creates a batch of examples.
-//
-// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result
-//
-// Arguments:
-// filename: The corpus's text file name.
-// batch_size: The size of produced batch.
-//
-// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids.
-func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Skipgram",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6)
-}
-
-// StringToNumberAttr is an optional argument to StringToNumber.
-type StringToNumberAttr func(optionalAttr)
-
-// StringToNumberOutType sets the optional out_type attribute to value.
-//
-// value: The numeric type to interpret each string in `string_tensor` as.
-// If not specified, defaults to DT_FLOAT
-func StringToNumberOutType(value tf.DataType) StringToNumberAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Converts each string in the input Tensor to the specified numeric type.
-//
-// (Note that int32 overflow results in an error while float overflow
-// results in a rounded value.)
-//
-// Returns A Tensor of the same shape as the input `string_tensor`.
-func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringToNumber",
- Input: []tf.Input{
- string_tensor,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2.
-type ResourceApplyFtrlV2Attr func(optionalAttr)
-
-// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the Ftrl-proximal scheme.
-//
-// grad_with_shrinkage = grad + 2 * l2_shrinkage * var
-// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
-// linear += grad_with_shrinkage +
-// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-// accum = accum_new
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// linear: Should be from a Variable().
-// grad: The gradient.
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regulariation. Must be a scalar.
-// l2: L2 shrinkage regulariation. Must be a scalar.
-//
-// lr_power: Scaling factor. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyFtrlV2",
- Input: []tf.Input{
- var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// TruncatedNormalAttr is an optional argument to TruncatedNormal.
-type TruncatedNormalAttr func(optionalAttr)
-
-// TruncatedNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func TruncatedNormalSeed(value int64) TruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// TruncatedNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func TruncatedNormalSeed2(value int64) TruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from a truncated normal distribution.
-//
-// The generated values follow a normal distribution with mean 0 and standard
-// deviation 1, except that values whose magnitude is more than 2 standard
-// deviations from the mean are dropped and re-picked.
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// dtype: The type of the output.
-//
-// Returns A tensor of the specified shape filled with random truncated normal
-// values.
-func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TruncatedNormal",
- Input: []tf.Input{
- shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2.
type MutableDenseHashTableV2Attr func(optionalAttr)
@@ -17191,6 +17153,34 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// Inverse fast Fourier transform.
+//
+// Computes the inverse 1-dimensional discrete Fourier transform over the
+// inner-most dimension of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most
+// dimension of `input` is replaced with its inverse 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.ifft
+// @end_compatibility
+func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IFFT",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// 2D fast Fourier transform.
//
// Computes the 2-dimensional discrete Fourier transform over the inner-most
@@ -17699,123 +17689,6 @@ func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp
return op.Output(0)
}
-// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize.
-type CudnnRNNParamsSizeAttr func(optionalAttr)
-
-// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value.
-// If not specified, defaults to "lstm"
-func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["rnn_mode"] = value
- }
-}
-
-// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value.
-// If not specified, defaults to "linear_input"
-func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["input_mode"] = value
- }
-}
-
-// CudnnRNNParamsSizeDirection sets the optional direction attribute to value.
-// If not specified, defaults to "unidirectional"
-func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["direction"] = value
- }
-}
-
-// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["dropout"] = value
- }
-}
-
-// CudnnRNNParamsSizeSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Computes size of weights that can be used by a Cudnn RNN model.
-//
-// Return the params size that can be used by the Cudnn RNN model. Subsequent
-// weight allocation and initialization should use this size.
-//
-// num_layers: Specifies the number of layers in the RNN model.
-// num_units: Specifies the size of the hidden state.
-// input_size: Specifies the size of the input state.
-// rnn_mode: Indicates the type of the RNN model.
-// input_mode: Indicate whether there is a linear projection between the input and
-// The actual computation before the first layer. 'skip_input' is only allowed
-// when input_size == num_units; 'auto_select' implies 'skip_input' when
-// input_size == num_units; otherwise, it implies 'linear_input'.
-// direction: Indicates whether a bidirectional model will be used.
-// dir = (direction == bidirectional) ? 2 : 1
-// dropout: dropout probability. When set to 0., dropout is disabled.
-// seed: the 1st part of a seed to initialize dropout.
-// seed2: the 2nd part of a seed to initialize dropout.
-// params_size: The size of the params buffer that should be allocated and
-// initialized for this RNN model. Note that this params buffer may not be
-// compatible across GPUs. Please use CudnnRNNParamsWeights and
-// CudnnRNNParamsBiases to save and restore them in a way that is compatible
-// across different runs.
-func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"T": T, "S": S}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "CudnnRNNParamsSize",
- Input: []tf.Input{
- num_layers, num_units, input_size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes gradients for SparseSegmentMean.
-//
-// Returns tensor "output" with same shape as grad, except for dimension 0 whose
-// value is output_dim0.
-//
-// Arguments:
-// grad: gradient propagated to the SparseSegmentMean op.
-// indices: indices passed to the corresponding SparseSegmentMean op.
-// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
-// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
-func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentMeanGrad",
- Input: []tf.Input{
- grad, indices, segment_ids, output_dim0,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the set of files matching one or more glob patterns.
//
// Note that this routine only supports wildcard characters in the
@@ -20548,6 +20421,151 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize.
+type CudnnRNNParamsSizeAttr func(optionalAttr)
+
+// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value.
+// If not specified, defaults to "lstm"
+func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["rnn_mode"] = value
+ }
+}
+
+// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value.
+// If not specified, defaults to "linear_input"
+func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["input_mode"] = value
+ }
+}
+
+// CudnnRNNParamsSizeDirection sets the optional direction attribute to value.
+// If not specified, defaults to "unidirectional"
+func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["direction"] = value
+ }
+}
+
+// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["dropout"] = value
+ }
+}
+
+// CudnnRNNParamsSizeSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Computes size of weights that can be used by a Cudnn RNN model.
+//
+// Return the params size that can be used by the Cudnn RNN model. Subsequent
+// weight allocation and initialization should use this size.
+//
+// num_layers: Specifies the number of layers in the RNN model.
+// num_units: Specifies the size of the hidden state.
+// input_size: Specifies the size of the input state.
+// rnn_mode: Indicates the type of the RNN model.
+// input_mode: Indicate whether there is a linear projection between the input and
+// The actual computation before the first layer. 'skip_input' is only allowed
+// when input_size == num_units; 'auto_select' implies 'skip_input' when
+// input_size == num_units; otherwise, it implies 'linear_input'.
+// direction: Indicates whether a bidirectional model will be used.
+// dir = (direction == bidirectional) ? 2 : 1
+// dropout: dropout probability. When set to 0., dropout is disabled.
+// seed: the 1st part of a seed to initialize dropout.
+// seed2: the 2nd part of a seed to initialize dropout.
+// params_size: The size of the params buffer that should be allocated and
+// initialized for this RNN model. Note that this params buffer may not be
+// compatible across GPUs. Please use CudnnRNNParamsWeights and
+// CudnnRNNParamsBiases to save and restore them in a way that is compatible
+// across different runs.
+func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"T": T, "S": S}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "CudnnRNNParamsSize",
+ Input: []tf.Input{
+ num_layers, num_units, input_size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes gradients for SparseSegmentMean.
+//
+// Returns tensor "output" with same shape as grad, except for dimension 0 whose
+// value is output_dim0.
+//
+// Arguments:
+// grad: gradient propagated to the SparseSegmentMean op.
+// indices: indices passed to the corresponding SparseSegmentMean op.
+// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
+func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMeanGrad",
+ Input: []tf.Input{
+ grad, indices, segment_ids, output_dim0,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
+//
+// N is the size of the segment being reduced.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSqrtN",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Compute the upper regularized incomplete Gamma function `Q(a, x)`.
//
// The upper regularized incomplete Gamma function is defined as:
@@ -31898,21 +31916,3 @@ func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Outpu
}
return scope.AddOperation(opspec)
}
-
-// Elementwise computes the bitwise AND of `x` and `y`.
-//
-// The result will have those bits set, that are set in both `x` and `y`. The
-// computation is performed on the underlying representations of `x` and `y`.
-func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BitwiseAnd",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/js/BUILD b/tensorflow/js/BUILD
new file mode 100644
index 0000000000..ad0dc44f54
--- /dev/null
+++ b/tensorflow/js/BUILD
@@ -0,0 +1,52 @@
+# Description:
+# JavaScript/TypeScript code generation for TensorFlow.js
+
+visibility = [
+ "//tensorflow:internal",
+]
+
+package(default_visibility = visibility)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "ts_op_gen",
+ srcs = [
+ "ops/ts_op_gen.cc",
+ ],
+ hdrs = [
+ "ops/ts_op_gen.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "ts_op_gen_test",
+ srcs = [
+ "ops/ts_op_gen.cc",
+ "ops/ts_op_gen.h",
+ "ops/ts_op_gen_test.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
new file mode 100644
index 0000000000..babf55cd5f
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -0,0 +1,199 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+#include <unordered_map>
+
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+
+static bool IsListAttr(const OpDef_ArgDef& arg) {
+ return !arg.type_list_attr().empty() || !arg.number_attr().empty();
+}
+
+// Struct to hold a combo OpDef and ArgDef for a given Op argument:
+struct ArgDefs {
+ ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
+ : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
+
+ const OpDef::ArgDef& op_def_arg;
+ const ApiDef::Arg& api_def_arg;
+};
+
+// Helper class to generate TypeScript code for a given OpDef:
+class GenTypeScriptOp {
+ public:
+ GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
+ ~GenTypeScriptOp();
+
+ // Returns the generated code as a string:
+ string Code();
+
+ private:
+ void ProcessArgs();
+
+ void AddMethodSignature();
+ void AddMethodReturnAndClose();
+
+ const OpDef& op_def_;
+ const ApiDef& api_def_;
+
+ // Placeholder string for all generated code:
+ string result_;
+
+ // Holds in-order vector of Op inputs:
+ std::vector<ArgDefs> input_op_args_;
+
+ // Holds number of outputs:
+ int num_outputs_;
+};
+
+GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
+ : op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
+
+GenTypeScriptOp::~GenTypeScriptOp() {}
+
+string GenTypeScriptOp::Code() {
+ ProcessArgs();
+
+ // Generate exported function for Op:
+ AddMethodSignature();
+ AddMethodReturnAndClose();
+
+ strings::StrAppend(&result_, "\n");
+ return result_;
+}
+
+void GenTypeScriptOp::ProcessArgs() {
+ for (int i = 0; i < api_def_.arg_order_size(); i++) {
+ auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
+ if (op_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find OpDef::ArgDef for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+ auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
+ if (api_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find ApiDef::Arg for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+ input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
+ }
+
+ num_outputs_ = api_def_.out_arg_size();
+}
+
+void GenTypeScriptOp::AddMethodSignature() {
+ strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
+ "(");
+
+ bool is_first = true;
+ for (auto& in_arg : input_op_args_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ", ");
+ }
+
+ auto op_def_arg = in_arg.op_def_arg;
+
+ strings::StrAppend(&result_, op_def_arg.name(), ": ");
+ if (IsListAttr(op_def_arg)) {
+ strings::StrAppend(&result_, "tfc.Tensor[]");
+ } else {
+ strings::StrAppend(&result_, "tfc.Tensor");
+ }
+ }
+
+ if (num_outputs_ == 1) {
+ strings::StrAppend(&result_, "): tfc.Tensor {\n");
+ } else {
+ strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
+ }
+}
+
+void GenTypeScriptOp::AddMethodReturnAndClose() {
+ strings::StrAppend(&result_, " return null;\n}\n");
+}
+
+void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
+ GenTypeScriptOp ts_op(op_def, api_def);
+ TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
+}
+
+void StartFile(WritableFile* ts_file) {
+ const string header =
+ R"header(/**
+ * @license
+ * Copyright 2018 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// This file is MACHINE GENERATED! Do not edit
+
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+
+)header";
+
+ TF_CHECK_OK(ts_file->Append(header));
+}
+
+} // namespace
+
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename) {
+ Env* env = Env::Default();
+
+ std::unique_ptr<WritableFile> ts_file = nullptr;
+ TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
+
+ StartFile(ts_file.get());
+
+ for (const auto& op_def : ops.op()) {
+ // Skip deprecated ops
+ if (op_def.has_deprecation() &&
+ op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
+ continue;
+ }
+
+ const auto* api_def = api_def_map.GetApiDef(op_def.name());
+ if (api_def->visibility() == ApiDef::VISIBLE) {
+ WriteTSOp(op_def, *api_def, ts_file.get());
+ }
+ }
+
+ TF_CHECK_OK(ts_file->Close());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/delegates/eager/constants.h b/tensorflow/js/ops/ts_op_gen.h
index 7ed6ab7552..fcd46a17a7 100644
--- a/tensorflow/contrib/lite/delegates/eager/constants.h
+++ b/tensorflow/js/ops/ts_op_gen.h
@@ -12,18 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
-namespace tflite {
-namespace eager {
+#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_
+#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_
-// The prefix of Eager op custom code.
-// This will be matched agains the `custom_code` field in `OperatorCode`
-// Flatbuffer Table.
-constexpr char kCustomCodePrefix[] = "Eager";
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/platform/types.h"
-} // namespace eager
-} // namespace tflite
+namespace tensorflow {
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
+// Generated code is written to the file ts_filename:
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
new file mode 100644
index 0000000000..9a85c021b0
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -0,0 +1,212 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+void ExpectContainsStr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(str_util::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
+ EXPECT_FALSE(str_util::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+// TODO(kreeger): Add multiple outputs here?
+constexpr char kBaseOpDef[] = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ description: "Images to process."
+ }
+ input_arg {
+ name: "dim"
+ description: "Description for dim."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for images"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ }
+ }
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Summary for op Foo."
+ description: "Description for op Foo."
+}
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
+// Generate TypeScript code
+// @param api_def_str TODO doc me.
+void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ ApiDefMap api_def_map(op_defs);
+
+ if (!api_def_str.empty()) {
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
+ }
+
+ const string& tmpdir = testing::TmpDir();
+ const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
+
+ WriteTSOps(op_defs, api_def_map, ts_file_path);
+ TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
+}
+
+TEST(TsOpGenTest, TestImports) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ const string expected = R"(
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, InputSingleAndList) {
+ const string api_def = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ }
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+ return null;
+}
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, TestVisibility) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ visibility: HIDDEN
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+ return null;
+}
+)";
+ ExpectDoesNotContainStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, SkipDeprecated) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
+}
+
+TEST(TsOpGenTest, MultiOutput) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ const string expected = R"(
+export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
+ return null;
+}
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5f985654f0..e6d78301a5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -44,7 +44,10 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_mpi_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_gdr_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-load("//third_party/ngraph:build_defs.bzl","if_ngraph")
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
py_library(
name = "python",
@@ -139,6 +142,7 @@ py_library(
"//tensorflow/python/ops/parallel_for",
"//tensorflow/python/profiler",
"//tensorflow/python/saved_model",
+ "//tensorflow/python/tools:component_api_helper",
"//third_party/py/numpy",
],
)
@@ -2780,11 +2784,13 @@ py_library(
srcs = ["ops/state_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":array_ops",
":framework_ops",
+ ":math_ops_gen",
":resource_variable_ops_gen",
":state_ops_gen",
":tensor_shape",
- "//tensorflow/python/eager:context",
+ ":util",
],
)
@@ -3266,6 +3272,7 @@ py_library(
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@@ -3813,8 +3820,9 @@ tf_py_wrap_cc(
tf_additional_plugin_deps() +
tf_additional_verbs_deps() +
tf_additional_mpi_deps() +
- tf_additional_gdr_deps()) +
- if_ngraph(["@ngraph_tf//:ngraph_tf"])
+ tf_additional_gdr_deps()) + if_ngraph([
+ "@ngraph_tf//:ngraph_tf",
+ ]),
)
# ** Targets for Windows build (start) **
@@ -4660,7 +4668,10 @@ py_test(
size = "medium",
srcs = ["training/monitored_session_test.py"],
srcs_version = "PY2AND3",
- tags = ["notsan"], # b/67945581
+ tags = [
+ "no_pip",
+ "notsan", # b/67945581
+ ],
deps = [
":array_ops",
":checkpoint_management",
@@ -4678,6 +4689,7 @@ py_test(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
index c94767a03c..80a256bf7a 100644
--- a/tensorflow/python/client/client_lib.py
+++ b/tensorflow/python/client/client_lib.py
@@ -15,7 +15,7 @@
"""Support for launching graphs and executing operations.
-See the @{$python/client} guide.
+See the [Client](https://tensorflow.org/api_guides/python/client) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 28f26ad27e..1841dd998b 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1500,7 +1500,7 @@ class Session(BaseSession):
Args:
target: (Optional.) The execution engine to connect to.
Defaults to using an in-process engine. See
- @{$distributed$Distributed TensorFlow}
+ [Distributed TensorFlow](https://tensorflow.org/deploy/distributed)
for more examples.
graph: (Optional.) The `Graph` to be launched (described above).
config: (Optional.) A
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 73adb7a559..e526bc89dd 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -14,8 +14,8 @@
# ==============================================================================
"""Utilities for API compatibility between TensorFlow release versions.
-See
-@{$guide/version_compat#backward_and_partial_forward_compatibility}
+See [Version
+Compatibility](https://tensorflow.org/guide/version_compat#backward_forward)
"""
from __future__ import absolute_import
@@ -26,14 +26,15 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 15)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 21)
@tf_export("compat.forward_compatible")
def forward_compatible(year, month, day):
"""Return true if the forward compatibility window has expired.
- See @{$guide/version_compat#backward_and_partial_forward_compatibility}.
+ See [Version
+ compatibility](https://tensorflow.org/guide/version_compat#backward_forward).
Forward-compatibility refers to scenarios where the producer of a TensorFlow
model (a GraphDef or SavedModel) is compiled against a version of the
@@ -91,7 +92,8 @@ def forward_compatible(year, month, day):
def forward_compatibility_horizon(year, month, day):
"""Context manager for testing forward compatibility of generated graphs.
- See @{$guide/version_compat#backward_and_partial_forward_compatibility}.
+ See [Version
+ compatibility](https://tensorflow.org/guide/version_compat#backward_forward).
To ensure forward compatibility of generated graphs (see `forward_compatible`)
with older binaries, new features can be gated with:
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index 3b9bf2469e..f8b561205e 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""`tf.data.Dataset` API for input pipelines.
-See @{$guide/datasets$Importing Data} for an overview.
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 50ba5f403e..57517afae8 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -27,6 +27,7 @@ py_library(
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8ba98cb88d..fdab8abfae 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -225,7 +225,7 @@ class Dataset(object):
`tf.constant` operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors contains
one or more large NumPy arrays, consider the alternative described in
- @{$guide/datasets#consuming_numpy_arrays$this guide}.
+ [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors.
@@ -244,7 +244,7 @@ class Dataset(object):
`tf.constant` operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors contains
one or more large NumPy arrays, consider the alternative described in
- @{$guide/datasets#consuming_numpy_arrays$this guide}.
+ [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors, each having the same size in the
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 5fcc62b60b..39082ce370 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -63,6 +63,41 @@ py_test(
)
py_library(
+ name = "structure",
+ srcs = ["structure.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nest",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "structure_test",
+ size = "small",
+ srcs = ["structure_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nest",
+ ":structure",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:variables",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_library(
name = "convert",
srcs = ["convert.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
new file mode 100644
index 0000000000..c5764b8dfe
--- /dev/null
+++ b/tensorflow/python/data/util/structure.py
@@ -0,0 +1,315 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for describing the structure of a `tf.data` type."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import sparse_ops
+
+
+class Structure(object):
+ """Represents structural information, such as type and shape, about a value.
+
+ A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape`
+ properties, so that we can define generic containers of objects including:
+
+ * `tf.Tensor`
+ * `tf.SparseTensor`
+ * Nested structures of the above.
+
+ TODO(b/110122868): In the future, a single `Structure` will replace the
+ `tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`,
+ and `tf.data.Dataset.output_classes`, and similar properties and arguments in
+ the `tf.data.Iterator` and `Optional` classes.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def _flat_shapes(self):
+ """A list of shapes matching the shapes of `self._to_tensor_list()`.
+
+ Returns:
+ A list of `tf.TensorShape` objects.
+ """
+ raise NotImplementedError("Structure._flat_shapes")
+
+ @abc.abstractproperty
+ def _flat_types(self):
+ """A list of types matching the types of `self._to_tensor_list()`.
+
+ Returns:
+ A list of `tf.DType` objects.
+ """
+ raise NotImplementedError("Structure._flat_shapes")
+
+ @abc.abstractmethod
+ def is_compatible_with(self, value):
+ """Returns `True` if `value` is compatible with this structure.
+
+ A value `value` is compatible with a structure `s` if
+ `Structure.from_value(value)` would return a structure `t` that is a
+ "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+
+ * `s` and `t` are instances of the same `Structure` subclass.
+ * The nested structures (if any) of `s` and `t` are the same, according to
+ `tf.contrib.framework.nest.assert_same_structure`, and each nested
+ structure of `t` is a "subtype" of the corresponding nested structure of
+ `s`.
+ * Any `tf.DType` components of `t` are the same as the corresponding
+ components in `s`.
+ * Any `tf.TensorShape` components of `t` are compatible with the
+ corresponding components in `s`, according to
+ `tf.TensorShape.is_compatible_with`.
+
+ Args:
+ value: A potentially structured value.
+
+ Returns:
+ `True` if `value` matches this structure, otherwise `False`.
+ """
+ raise NotImplementedError("Structure.is_compatible_with()")
+
+ @abc.abstractmethod
+ def _to_tensor_list(self, value):
+ """Returns a flat list of `tf.Tensor` representing `value`.
+
+ This method can be used, along with `self._flat_shapes` and
+ `self._flat_types` to represent structured values in lower level APIs
+ (such as plain TensorFlow operations) that do not understand structure.
+
+ Requires: `self.is_compatible_with(value)`.
+
+ Args:
+ value: A value with compatible structure.
+
+ Returns:
+ A flat list of `tf.Tensor` representing `value`.
+ """
+ raise NotImplementedError("Structure._to_tensor_list()")
+
+ @abc.abstractmethod
+ def _from_tensor_list(self, flat_value):
+ """Builds a flat list of `tf.Tensor` into a value matching this structure.
+
+ Requires: The shapes and types of the tensors in `flat_value` must be
+ compatible with `self._flat_shapes` and `self._flat_types` respectively.
+
+ Args:
+ flat_value: A list of `tf.Tensor` with compatible flat structure.
+
+ Returns:
+ A structured object matching this structure.
+ """
+ raise NotImplementedError("Structure._from_tensor_list()")
+
+ @staticmethod
+ def from_value(value):
+ """Returns a `Structure` that represents the given `value`.
+
+ Args:
+ value: A potentially structured value.
+
+ Returns:
+ A `Structure` that is compatible with `value`.
+
+ Raises:
+ TypeError: If a structure cannot be built for `value`, because its type
+ or one of its component types is not supported.
+ """
+
+ # TODO(b/110122868): Add support for custom types, Dataset, and Optional
+ # to this method.
+ if isinstance(
+ value,
+ (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
+ return SparseTensorStructure.from_value(value)
+ elif isinstance(value, (tuple, dict)):
+ return NestedStructure.from_value(value)
+ else:
+ try:
+ tensor = ops.convert_to_tensor(value)
+ except (ValueError, TypeError):
+ raise TypeError("Could not build a structure for %r" % value)
+ return TensorStructure.from_value(tensor)
+
+
+# NOTE(mrry): The following classes make extensive use of non-public methods of
+# their base class, so we disable the protected-access lint warning once here.
+# pylint: disable=protected-access
+class NestedStructure(Structure):
+ """Represents a nested structure in which each leaf is a `Structure`."""
+
+ def __init__(self, nested_structure):
+ self._nested_structure = nested_structure
+ self._flat_shapes_list = []
+ self._flat_types_list = []
+ for s in nest.flatten(nested_structure):
+ if not isinstance(s, Structure):
+ raise TypeError("nested_structure must be a (potentially nested) tuple "
+ "or dictionary of Structure objects.")
+ self._flat_shapes_list.extend(s._flat_shapes)
+ self._flat_types_list.extend(s._flat_types)
+
+ @property
+ def _flat_shapes(self):
+ return self._flat_shapes_list
+
+ @property
+ def _flat_types(self):
+ return self._flat_types_list
+
+ def is_compatible_with(self, value):
+ try:
+ nest.assert_shallow_structure(self._nested_structure, value)
+ except (ValueError, TypeError):
+ return False
+
+ return all(
+ s.is_compatible_with(v) for s, v in zip(
+ nest.flatten(self._nested_structure),
+ nest.flatten_up_to(self._nested_structure, value)))
+
+ def _to_tensor_list(self, value):
+ ret = []
+
+ try:
+ flat_value = nest.flatten_up_to(self._nested_structure, value)
+ except (ValueError, TypeError):
+ raise ValueError("The value %r is not compatible with the nested "
+ "structure %r." % (value, self._nested_structure))
+
+ for sub_value, structure in zip(flat_value,
+ nest.flatten(self._nested_structure)):
+ if not structure.is_compatible_with(sub_value):
+ raise ValueError("Component value %r is not compatible with the nested "
+ "structure %r." % (sub_value, structure))
+ ret.extend(structure._to_tensor_list(sub_value))
+ return ret
+
+ def _from_tensor_list(self, flat_value):
+ if len(flat_value) != len(self._flat_types):
+ raise ValueError("Expected %d flat values in NestedStructure but got %d."
+ % (len(self._flat_types), len(flat_value)))
+
+ flat_ret = []
+ for sub_value, structure in zip(flat_value,
+ nest.flatten(self._nested_structure)):
+ flat_ret.append(structure._from_tensor_list([sub_value]))
+
+ return nest.pack_sequence_as(self._nested_structure, flat_ret)
+
+ @staticmethod
+ def from_value(value):
+ flat_nested_structure = [
+ Structure.from_value(sub_value) for sub_value in nest.flatten(value)
+ ]
+ return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
+
+
+class TensorStructure(Structure):
+ """Represents structural information about a `tf.Tensor`."""
+
+ def __init__(self, dtype, shape):
+ self._dtype = dtypes.as_dtype(dtype)
+ self._shape = tensor_shape.as_shape(shape)
+
+ @property
+ def _flat_shapes(self):
+ return [self._shape]
+
+ @property
+ def _flat_types(self):
+ return [self._dtype]
+
+ def is_compatible_with(self, value):
+ try:
+ value = ops.convert_to_tensor(value, dtype=self._dtype)
+ except (ValueError, TypeError):
+ return False
+
+ return (self._dtype.is_compatible_with(value.dtype) and
+ self._shape.is_compatible_with(value.shape))
+
+ def _to_tensor_list(self, value):
+ if not self.is_compatible_with(value):
+ raise ValueError("Value %r is not convertible to a tensor with dtype %s "
+ "and shape %s." % (value, self._dtype, self._shape))
+ return [value]
+
+ def _from_tensor_list(self, flat_value):
+ if len(flat_value) != 1:
+ raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
+ if not self.is_compatible_with(flat_value[0]):
+ raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
+ "%s." % (flat_value[0], self._dtype, self._shape))
+ return flat_value[0]
+
+ @staticmethod
+ def from_value(value):
+ return TensorStructure(value.dtype, value.shape)
+
+
+class SparseTensorStructure(Structure):
+ """Represents structural information about a `tf.SparseTensor`."""
+
+ def __init__(self, dtype, dense_shape):
+ self._dtype = dtypes.as_dtype(dtype)
+ self._dense_shape = tensor_shape.as_shape(dense_shape)
+
+ @property
+ def _flat_shapes(self):
+ return [tensor_shape.vector(3)]
+
+ @property
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, value):
+ try:
+ value = sparse_tensor_lib.SparseTensor.from_value(value)
+ except TypeError:
+ return False
+ return (isinstance(value, (sparse_tensor_lib.SparseTensor,
+ sparse_tensor_lib.SparseTensorValue)) and
+ self._dtype.is_compatible_with(value.dtype) and
+ self._dense_shape.is_compatible_with(
+ tensor_util.constant_value_as_shape(value.dense_shape)))
+
+ def _to_tensor_list(self, value):
+ return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))):
+ raise ValueError("SparseTensorStructure corresponds to a single "
+ "tf.variant vector of length 3.")
+ return sparse_ops.deserialize_sparse(
+ flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
+
+ @staticmethod
+ def from_value(value):
+ sparse_tensor = sparse_tensor_lib.SparseTensor.from_value(value)
+ return SparseTensorStructure(
+ sparse_tensor.dtype,
+ tensor_util.constant_value_as_shape(sparse_tensor.dense_shape))
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
new file mode 100644
index 0000000000..d0c7df67ae
--- /dev/null
+++ b/tensorflow/python/data/util/structure_test.py
@@ -0,0 +1,327 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utilities working with arbitrarily nested structures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import structure
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class StructureTest(test.TestCase, parameterized.TestCase):
+ # pylint disable=protected-access
+
+ @parameterized.parameters(
+ (constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32],
+ [[]]), (sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ structure.SparseTensorStructure, [dtypes.variant], [[3]]),
+ ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
+ structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), ({
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+ ({
+ "a":
+ constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
+ }, structure.NestedStructure,
+ [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
+ def testFlatStructure(self, value, expected_structure, expected_types,
+ expected_shapes):
+ s = structure.Structure.from_value(value)
+ self.assertIsInstance(s, expected_structure)
+ self.assertEqual(expected_types, s._flat_types)
+ self.assertEqual(expected_shapes, s._flat_shapes)
+
+ @parameterized.parameters(
+ (constant_op.constant(37.0), [
+ constant_op.constant(38.0),
+ array_ops.placeholder(dtypes.float32),
+ variables.Variable(100.0), 42.0,
+ np.array(42.0, dtype=np.float32)
+ ], [constant_op.constant([1.0, 2.0]),
+ constant_op.constant(37)]),
+ (sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ [
+ sparse_tensor.SparseTensor(
+ indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
+ sparse_tensor.SparseTensorValue(
+ indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
+ array_ops.sparse_placeholder(dtype=dtypes.int32),
+ array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
+ ], [
+ constant_op.constant(37, shape=[4, 5]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
+ array_ops.sparse_placeholder(
+ dtype=dtypes.int32, shape=[None, None, None]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
+ ]),
+ ({
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }, [{
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6])
+ }], [{
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6, 7])
+ }, {
+ "a": constant_op.constant(15),
+ "b": constant_op.constant([4, 5, 6])
+ }, {
+ "a":
+ constant_op.constant(15),
+ "b":
+ sparse_tensor.SparseTensor(
+ indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
+ }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
+ )
+ def testIsCompatibleWith(self, original_value, compatible_values,
+ incompatible_values):
+ s = structure.Structure.from_value(original_value)
+ for compatible_value in compatible_values:
+ self.assertTrue(s.is_compatible_with(compatible_value))
+ for incompatible_value in incompatible_values:
+ self.assertFalse(s.is_compatible_with(incompatible_value))
+
+ # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
+ # will be executed before the (eager- or graph-mode) test environment has been
+ # set up.
+ # pylint: disable=g-long-lambda
+ @parameterized.parameters(
+ (lambda: constant_op.constant(37.0),),
+ (lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
+ (lambda: {"a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])},),
+ (lambda: {"a": constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
+ },),
+ )
+ def testRoundTripConversion(self, value_fn):
+ value = value_fn()
+ s = structure.Structure.from_value(value)
+ before = self.evaluate(value)
+ after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))
+
+ flat_before = nest.flatten(before)
+ flat_after = nest.flatten(after)
+ for b, a in zip(flat_before, flat_after):
+ if isinstance(b, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(b.indices, a.indices)
+ self.assertAllEqual(b.values, a.values)
+ self.assertAllEqual(b.dense_shape, a.dense_shape)
+ else:
+ self.assertAllEqual(b, a)
+ # pylint: enable=g-long-lambda
+
+ def testIncompatibleStructure(self):
+ # Define three mutually incompatible values/structures, and assert that:
+ # 1. Using one structure to flatten a value with an incompatible structure
+ # fails.
+ # 2. Using one structure to restructre a flattened value with an
+ # incompatible structure fails.
+ value_tensor = constant_op.constant(42.0)
+ s_tensor = structure.Structure.from_value(value_tensor)
+ flat_tensor = s_tensor._to_tensor_list(value_tensor)
+
+ value_sparse_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])
+ s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
+ flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)
+
+ value_nest = {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }
+ s_nest = structure.Structure.from_value(value_nest)
+ flat_nest = s_nest._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"SparseTensor.* is not convertible to a tensor with "
+ r"dtype.*float32.* and shape \(\)"):
+ s_tensor._to_tensor_list(value_sparse_tensor)
+ with self.assertRaisesRegexp(
+ ValueError, r"Value \{.*\} is not convertible to a tensor with "
+ r"dtype.*float32.* and shape \(\)"):
+ s_tensor._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
+ s_sparse_tensor._to_tensor_list(value_tensor)
+
+ with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
+ s_sparse_tensor._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.* not compatible with the nested structure "
+ ".*TensorStructure.*TensorStructure"):
+ s_nest._to_tensor_list(value_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.* not compatible with the nested structure "
+ ".*TensorStructure.*TensorStructure"):
+ s_nest._to_tensor_list(value_sparse_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
+ s_tensor._from_tensor_list(flat_sparse_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "TensorStructure corresponds to a single tf.Tensor."):
+ s_tensor._from_tensor_list(flat_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_sparse_tensor._from_tensor_list(flat_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_sparse_tensor._from_tensor_list(flat_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 1."):
+ s_nest._from_tensor_list(flat_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 1."):
+ s_nest._from_tensor_list(flat_sparse_tensor)
+
+ def testIncompatibleNestedStructure(self):
+ # Define three mutually incompatible nested values/structures, and assert
+ # that:
+ # 1. Using one structure to flatten a value with an incompatible structure
+ # fails.
+ # 2. Using one structure to restructre a flattened value with an
+ # incompatible structure fails.
+
+ value_0 = {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }
+ s_0 = structure.Structure.from_value(value_0)
+ flat_s_0 = s_0._to_tensor_list(value_0)
+
+ # `value_1` has compatible nested structure with `value_0`, but different
+ # classes.
+ value_1 = {
+ "a":
+ constant_op.constant(37.0),
+ "b":
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])
+ }
+ s_1 = structure.Structure.from_value(value_1)
+ flat_s_1 = s_1._to_tensor_list(value_1)
+
+ # `value_2` has incompatible nested structure with `value_0` and `value_1`.
+ value_2 = {
+ "a":
+ constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
+ }
+ s_2 = structure.Structure.from_value(value_2)
+ flat_s_2 = s_2._to_tensor_list(value_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.* not compatible with the nested structure "
+ ".*TensorStructure"):
+ s_0._to_tensor_list(value_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
+ "nested structure .*TensorStructure"):
+ s_0._to_tensor_list(value_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.* not compatible with the nested structure "
+ ".*SparseTensorStructure"):
+ s_1._to_tensor_list(value_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
+ "nested structure .*TensorStructure"):
+ s_0._to_tensor_list(value_2)
+
+ # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
+ # needs to account for "a" coming before or after "b". It might be worth
+ # adding a deterministic repr for these error messages (among other
+ # improvements).
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
+ ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
+ "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
+ s_2._to_tensor_list(value_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
+ "not compatible with the nested structure .*"
+ "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
+ "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
+ s_2._to_tensor_list(value_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
+ s_0._from_tensor_list(flat_s_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 3."):
+ s_0._from_tensor_list(flat_s_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_1._from_tensor_list(flat_s_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 3."):
+ s_1._from_tensor_list(flat_s_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 3 flat values in NestedStructure but got 2."):
+ s_2._from_tensor_list(flat_s_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 3 flat values in NestedStructure but got 2."):
+ s_2._from_tensor_list(flat_s_1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index 34da44b60d..242215dccb 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Public Python API of TensorFlow Debugger (tfdbg).
-See the @{$python/tfdbg} guide.
+See the [TFDBG](https://tensorflow.org/api_guides/python/tfdbg) guide.
@@add_debug_tensor_watch
@@watch_graph
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 68d8b8d13b..98ef9bf492 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -41,3 +41,43 @@ py_test(
"//tensorflow/python:variables",
],
)
+
+py_library(
+ name = "distribute_coordinator_context",
+ srcs = [
+ "distribute_coordinator_context.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [],
+)
+
+py_library(
+ name = "multi_worker_util",
+ srcs = [
+ "multi_worker_util.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:training",
+ ],
+)
+
+py_test(
+ name = "multi_worker_util_test",
+ srcs = ["multi_worker_util_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":multi_worker_util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index fc9ca4ac4a..eb081b65fc 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""
from __future__ import absolute_import
from __future__ import division
@@ -24,6 +24,8 @@ import os
import threading
from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -43,23 +45,12 @@ class CoordinatorMode(object):
# client and connects to remote servers for training. Each remote server can
# use the distribute coordinator binary with task_type set correctly which
# will then turn into standard servers.
- SPLIT_CLIENT = 0
+ STANDALONE_CLIENT = "standalone_client"
# The distribute coordinator runs on each worker. It will run a standard
# server on each worker and optionally run the `worker_fn` that is configured
# to talk to its standard server.
- INDEPENDENT_WORKER = 1
-
-
-_worker_context = threading.local()
-
-
-def get_current_worker_context():
- """Returns the current task context."""
- try:
- return _worker_context.current
- except AttributeError:
- return None
+ INDEPENDENT_WORKER = "independent_worker"
class _Barrier(object):
@@ -113,14 +104,17 @@ class _WorkerContext(object):
"""
def __init__(self,
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=None,
rpc_layer="grpc",
worker_barrier=None):
"""Initialize the worker context object.
Args:
+ strategy: a `DistributionStrategy` object.
cluster_spec: a ClusterSpec object. It can be empty or None in the local
training case.
task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +122,17 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
+ session_config: an optional @{tf.ConfigProto} object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
worker_barrier: optional, the barrier object for worker synchronization.
"""
+ self._strategy = strategy
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
+ self._session_config = session_config
self._worker_barrier = worker_barrier
self._rpc_layer = rpc_layer
self._master_target = self._get_master_target()
@@ -143,26 +140,31 @@ class _WorkerContext(object):
self._is_chief_node = self._is_chief()
def _debug_message(self):
- return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
- self._cluster_spec, self.task_type, self.task_id)
+ if self._cluster_spec:
+ return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+ self._cluster_spec, self.task_type, self.task_id)
+ else:
+ return "[local]"
def __enter__(self):
- old_context = get_current_worker_context()
+ old_context = distribute_coordinator_context.get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.\t" +
self._debug_message())
- _worker_context.current = self
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
- _worker_context.current = None
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
- return "local"
+ return ""
# If task_type is None, then it is in-graph replicated training. In this
# case we use the chief or first worker's master target.
@@ -207,6 +209,47 @@ class _WorkerContext(object):
self._debug_message())
self._worker_barrier.wait()
+ def session_creator(self,
+ scaffold=None,
+ config=None,
+ checkpoint_dir=None,
+ checkpoint_filename_with_path=None,
+ max_wait_secs=7200):
+ """Returns a session creator.
+
+ The returned session creator will be configured with the correct master
+ target and session configs. It will also run either init ops or ready ops
+ by querying the `strategy` object when `create_session` is called on it.
+
+ Args:
+ scaffold: A `Scaffold` used for gathering or building supportive ops. If
+ not specified a default one is created. It's used to finalize the graph.
+ config: `ConfigProto` proto used to configure the session.
+ checkpoint_dir: A string. Optional path to a directory where to restore
+ variables.
+ checkpoint_filename_with_path: Full file name path to the checkpoint file.
+ Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+ specified.
+ max_wait_secs: Maximum time to wait for the session to become available.
+
+ Returns:
+ a descendant of SessionCreator.
+ """
+ # TODO(yuefengz): merge session config.
+ if self._strategy.should_init:
+ return monitored_session.ChiefSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=config or self._session_config,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_filename_with_path=checkpoint_filename_with_path)
+ else:
+ return monitored_session.WorkerSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=config or self._session_config,
+ max_wait_secs=max_wait_secs)
+
@property
def has_barrier(self):
"""Whether the barrier is set or not."""
@@ -247,21 +290,38 @@ class _WorkerContext(object):
"""Returns number of workers in the cluster, including chief."""
return self._num_workers
+ @property
+ def should_checkpoint(self):
+ """Whether to save checkpoint."""
+ return self._strategy.should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ """Whether to save summaries."""
+ return self._strategy.should_save_summary
+
def _run_single_worker(worker_fn,
+ strategy,
cluster_spec,
task_type,
task_id,
- rpc_layer,
+ session_config,
+ rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
- with _WorkerContext(
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(session_config, cluster_spec, task_type, task_id)
+ context = _WorkerContext(
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
- worker_barrier=worker_barrier):
- worker_fn()
+ worker_barrier=worker_barrier)
+ with context:
+ worker_fn(strategy)
def _run_std_server(cluster_spec=None,
@@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
return server
-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
@@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
t = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, task_type, task_id),
+ args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
"worker_barrier": worker_barrier
@@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
eval_thread.join()
-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
eval_thread.start()
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(
+ worker_fn,
+ strategy,
+ cluster_spec,
+ None,
+ None,
+ session_config,
+ rpc_layer=rpc_layer)
if eval_thread:
eval_thread.join()
-
-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
- mode=CoordinatorMode.SPLIT_CLIENT,
+ strategy,
+ mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
task_id=None,
- between_graph=False,
+ session_config=None,
rpc_layer="grpc"):
"""Runs the coordinator for distributed TensorFlow.
This function runs a split coordinator for distributed TensorFlow in its
- default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
- server addresses and their roles in a cluster, this coordinator will figure
- out how to set them up, give the underlying function the right targets for
- master sessions via a scope object and coordinate their training. The cluster
- consisting of standard servers needs to be brought up either with the standard
- server binary or with a binary running distribute coordinator with `task_type`
- set to non-client type which will then turn into standard servers.
+ default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec`
+ specifying server addresses and their roles in a cluster, this coordinator
+ will figure out how to set them up, give the underlying function the right
+ targets for master sessions via a scope object and coordinate their training.
+ The cluster consisting of standard servers needs to be brought up either with
+ the standard server binary or with a binary running distribute coordinator
+ with `task_type` set to non-client type which will then turn into standard
+ servers.
In addition to be the distribute coordinator, this is also the source of
configurations for each job in the distributed training. As there are multiple
@@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
`worker_fn` depending whether it is between-graph training or in-graph
replicated training.
+ The `strategy` object is expected to be a DistributionStrategy object which
+ has implemented methods needed by distributed coordinator such as
+ `configure(session_config, cluster_spec, task_type, task_id)` which configures
+ the strategy object for a specific task and `should_init` property which
+ instructs the distribute coordinator whether to run init ops for a task. The
+ distribute coordinator will make a copy of the `strategy` object, call its
+ `configure` method and pass it to `worker_fn` as an argument.
+
The `worker_fn` defines the training logic and is called under a its own
worker context which can be accessed to via `get_current_worker_context`. A
worker context provides access to configurations for each task, e.g. the
@@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn,
evaluation.
Args:
- worker_fn: the function to be called and given the access to a coordinator
- context object.
+ worker_fn: the function to be called. The function should accept a
+ `strategy` object and will be given access to a context object via a
+ context manager scope.
+ strategy: a DistributionStrategy object which specifying whether it should
+ run between-graph replicated training or not, whether to run init ops,
+ etc. This object will also be configured given `session_config`,
+ `cluster_spc`, `task_type` and `task_id`.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- between_graph: a boolean. It is only useful when `cluster_spec` is set and
- not empty. If true, it will use between-graph replicated training;
- otherwise it will use in-graph replicated training.
+ session_config: an optional @{tf.ConfigProto} object which will be passed
+ to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
Raises:
@@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn,
if not cluster_spec:
# `mode` is ignored in the local case.
- _run_single_worker(worker_fn, None, None, None, rpc_layer)
- elif mode == CoordinatorMode.SPLIT_CLIENT:
+ _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+ rpc_layer)
+ elif mode == CoordinatorMode.STANDALONE_CLIENT:
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
- if between_graph:
- _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+ if strategy.between_graph:
+ _run_between_graph_client(worker_fn, strategy, cluster_spec,
+ session_config, rpc_layer)
else:
- _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+ _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
@@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn,
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
- if between_graph:
+ if strategy.between_graph:
# All jobs run `worker_fn` if between-graph.
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
- rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
# Only one node runs `worker_fn` if in-graph.
- context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+ context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
if context.is_chief:
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+ session_config, rpc_layer)
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
+ session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/python/distribute/distribute_coordinator_context.py b/tensorflow/python/distribute/distribute_coordinator_context.py
new file mode 100644
index 0000000000..dee65ce883
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator_context.py
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The context retrieval method for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+_worker_context = threading.local()
+
+
+def get_current_worker_context():
+ """Returns the current task context."""
+ try:
+ return _worker_context.current
+ except AttributeError:
+ return None
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 319c29ba2f..97c6bdd15a 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
@@ -37,6 +37,7 @@ except ImportError as _error:
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
-RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
-
NUM_WORKERS = 3
NUM_PS = 2
@@ -74,6 +75,57 @@ def _strip_protocol(target):
return target
+class MockStrategy(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=None,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self._between_graph = between_graph
+ self._should_init = should_init
+ self._should_checkpoint = should_checkpoint
+ self._should_save_summary = should_save_summary
+
+ @property
+ def between_graph(self):
+ return self._between_graph
+
+ def configure(self,
+ session_options=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del session_options, cluster_spec, task_type
+ if self._should_init is None:
+ if task_id == 0:
+ self._should_init = True
+ else:
+ self._should_init = False
+ if self._should_checkpoint is None:
+ if task_id == 0:
+ self._should_checkpoint = True
+ else:
+ self._should_checkpoint = False
+ if self._should_save_summary is None:
+ if task_id == 0:
+ self._should_save_summary = True
+ else:
+ self._should_save_summary = False
+
+ @property
+ def should_init(self):
+ return self._should_init
+
+ @property
+ def should_checkpoint(self):
+ return self._should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ return self._should_save_summary
+
+
class MockServer(object):
def __init__(self):
@@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
+ self._strategy_property = {}
self._std_servers = {}
self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
- def _in_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _in_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
if result_value == expected:
self._result_correct += 1
- def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+ def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
t = threading.Thread(
target=distribute_coordinator.run_distribute_coordinator,
- args=(worker_fn,),
+ args=(worker_fn, strategy),
kwargs=kwargs)
t.start()
return t
- def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
- **kwargs):
+ def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+ cluster_spec, **kwargs):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_coordinator_in_thread(
worker_fn,
+ strategy,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
@@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
threads[task_type].append(t)
return threads
- def _between_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _between_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
with self._lock:
self._result_correct += 1
- def _dump_worker_context(self):
+ def _between_graph_with_monitored_session(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+ with ops.device("/job:ps/task:0"):
+ # TODO(yuefengz): investigate why not using resource variable will make
+ # the test flaky.
+ x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+ with ops.device("/job:ps/task:1"):
+ y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+
+ x_add = x.assign_add(2.0)
+ y_sub = y.assign_sub(2.0)
+ train_op = control_flow_ops.group([x_add, y_sub])
+
+ # The monitored session will run init or ready ops.
+ with monitored_session.MonitoredSession() as sess:
+ sess.run(train_op)
+
+ # Synchronize workers after one step to make sure they all have finished
+ # training.
+ if context.has_barrier:
+ context.wait_for_other_workers()
+ else:
+ self._barrier.wait()
+
+ x_val, y_val = sess.run([x, y])
+
+ self.assertEqual(x_val, 16.0)
+ self.assertEqual(y_val, 14.0)
+ if x_val == 16.0 and y_val == 14.0:
+ with self._lock:
+ self._result_correct += 1
+
+ def _dump_worker_context(self, strategy):
"""Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distribute_mode, where
the list is indexed by the task_id.
+
+ Args:
+ strategy: a `DistributionStrategy` object.
"""
- context = distribute_coordinator.get_current_worker_context()
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
@@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase):
context.is_chief,
context.distributed_mode)
+ def _dump_strategy_property(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+
+ self.assertEqual(context._strategy.should_init, strategy.should_init)
+ self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+ self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+ task_type = str(context.task_type)
+ task_id = context.task_id or 0
+ with self._lock:
+ if task_type not in self._strategy_property:
+ self._strategy_property[task_type] = []
+ while len(self._strategy_property[task_type]) <= task_id:
+ self._strategy_property[task_type].append(None)
+ self._strategy_property[task_type][task_id] = (
+ context._strategy.should_init, context.should_checkpoint,
+ context.should_save_summary)
+
def _run_mock_std_server(self,
session_config=None,
cluster_spec=None,
@@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
return server
-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
- def testInGraphSplitMode(self):
- """Test it runs in-graph replication in split client mode."""
+ def testInGraphStandaloneMode(self):
+ """Test it runs in-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._in_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
- """Test it runs between-graph replication in split client mode."""
+ """Test it runs between-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ """Test monitored session in standalone client mode."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# There is only one type of task and there three such tasks.
self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
+ def testBetweenGraphStrategyProperties(self):
+ # Dumps properties of the strategy objects.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec=self._cluster_spec)
+
+ # There is only one type of task and there three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
def testLocalContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_worker_context, cluster_spec=None, between_graph=True)
+ self._dump_worker_context,
+ MockStrategy(between_graph=False),
+ cluster_spec=None)
# There is only a "None" task.
self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+ self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec=cluster_spec,
- between_graph=True,
rpc_layer="grpc")
# There are one CHIEF and three workers.
@@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec=cluster_spec,
- between_graph=False,
rpc_layer=None)
# There are one "None" task and one EVALUATOR task.
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
threads = self._run_multiple_coordinator_in_threads(
self._in_graph_worker_fn,
+ MockStrategy(between_graph=False),
cluster_spec,
- between_graph=False,
mode=INDEPENDENT_WORKER)
threads[WORKER][0].join()
self.assertEqual(self._result_correct, 1)
@@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_worker_fn,
+ MockStrategy(between_graph=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ cluster_spec = self._create_cluster_spec(
+ num_workers=NUM_WORKERS, num_ps=NUM_PS)
+ threads = self._run_multiple_coordinator_in_threads(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
cluster_spec,
- between_graph=True,
mode=INDEPENDENT_WORKER)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=True,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertFalse(self._std_servers[WORKER][1].joined)
self.assertFalse(self._std_servers[WORKER][2].joined)
+ def testBetweenGraphStrategyProperties(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ # Dumps properties of the strategy objects.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # There is only one type of task and there three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
new file mode 100644
index 0000000000..360733eff6
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for multi-worker distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import server_lib
+
+
+def normalize_cluster_spec(cluster_spec):
+ """Makes `cluster_spec` into a `ClusterSpec` object.
+
+ Args:
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+
+ Returns:
+ a `ClusterSpec` object.
+
+ Raises:
+ ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
+ `ClusterDef`.
+ """
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+def is_chief(cluster_spec, task_type, task_id):
+ """Returns whether the given task is chief in the cluster.
+
+ Args:
+ cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
+ cluster configurations.
+ task_type: the task type in the cluster.
+ task_id: the task id in the cluster.
+
+ Returns:
+ a boolean indicating whether the given task is chief.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
+ the maximum id of the `task_type`.
+ """
+ cluster_spec = normalize_cluster_spec(cluster_spec)
+ if task_type not in cluster_spec.jobs:
+ raise ValueError(
+ "The task_type \"%s\" is not in the `cluster_spec`." % task_type)
+ if task_id >= cluster_spec.num_tasks(task_type):
+ raise ValueError("The `task_id` %d exceeds the maximum id of %s." % (
+ task_id, task_type))
+
+ if task_type == "chief":
+ return True
+
+ # If chief not in the cluster_spec, use the first worker as chief. This is
+ # common in CollectiveAllReduceStrategy.
+ if ("chief" not in cluster_spec.jobs and task_type == "worker" and
+ task_id == 0):
+ return True
+ return False
diff --git a/tensorflow/python/distribute/multi_worker_util_test.py b/tensorflow/python/distribute/multi_worker_util_test.py
new file mode 100644
index 0000000000..bdc49725c7
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util_test.py
@@ -0,0 +1,107 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi_worker_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import test
+from tensorflow.python.training import server_lib
+
+
+class NormalizeClusterSpecTest(test.TestCase):
+
+ def assert_same_cluster(self, lhs, rhs):
+ self.assertEqual(
+ server_lib.ClusterSpec(lhs).as_dict(),
+ server_lib.ClusterSpec(rhs).as_dict())
+
+ def testDictAsInput(self):
+ cluster_spec = {
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ }
+ self.assert_same_cluster(
+ cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+ def testClusterDefAsInput(self):
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = "chief"
+ job.tasks[0] = "127.0.0.1:1234"
+
+ job = cluster_def.job.add()
+ job.name = "worker"
+ job.tasks[0] = "127.0.0.1:8964"
+ job.tasks[1] = "127.0.0.1:2333"
+
+ job = cluster_def.job.add()
+ job.name = "ps"
+ job.tasks[0] = "127.0.0.1:1926"
+ job.tasks[1] = "127.0.0.1:3141"
+
+ self.assert_same_cluster(
+ cluster_def, multi_worker_util.normalize_cluster_spec(cluster_def))
+
+ def testClusterSpecAsInput(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ })
+ self.assert_same_cluster(
+ cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+ def testUnexpectedInput(self):
+ cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"]
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object"):
+ multi_worker_util.normalize_cluster_spec(cluster_spec)
+
+
+class IsChiefTest(test.TestCase):
+
+ def testClusterWithChief(self):
+ cluster_spec = {
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ }
+ self.assertTrue(multi_worker_util.is_chief(cluster_spec, "chief", 0))
+ self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+
+ def testClusterWithoutChief(self):
+ cluster_spec = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
+ self.assertTrue(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+ self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))
+
+ with self.assertRaisesRegexp(
+ ValueError, "The task_type \"chief\" is not in the `cluster_spec`."):
+ multi_worker_util.is_chief(cluster_spec, "chief", 0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
+ multi_worker_util.is_chief(cluster_spec, "worker", 2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index de93b1e2e1..bdabbf4ea3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -47,7 +47,6 @@ py_library(
":core",
":execute",
":function",
- ":graph_callable",
":graph_only_ops",
":tape",
":test",
@@ -254,41 +253,6 @@ py_library(
)
py_library(
- name = "graph_callable",
- srcs = ["graph_callable.py"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/eager:function",
- "//tensorflow/python/eager:tape",
- ],
-)
-
-py_test(
- name = "graph_callable_test",
- srcs = ["graph_callable_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":backprop",
- ":graph_callable",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:function",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:test",
- ],
-)
-
-py_library(
name = "backprop",
srcs = ["backprop.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 553f761a14..7978383e55 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
@@ -180,10 +181,10 @@ def implicit_val_and_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a tuple pair.
@@ -255,10 +256,10 @@ def implicit_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a list of (gradient, variable) pairs.
@@ -343,24 +344,24 @@ def gradients_function(f, params=None):
Note that only tensors with real or complex dtypes are differentiable.
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing None
- differentiates with respect to all parameters.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing None
+ differentiates with respect to all parameters.
Returns:
function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
+ of `f` with respect to all of `params`. The function takes an extra optional
+ keyword argument `dy`. Setting it allows computation of vector jacobian
products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -440,23 +441,24 @@ def val_and_grad_function(f, params=None):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing `None`
- differentiates with respect to all parameters.
-
- Returns: function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
- products for vectors other than the vector of ones.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing `None`
+ differentiates with respect to all parameters.
+
+ Returns:
+ function which, when called, returns the value of f and the gradient
+ of f with respect to all of `params`. The function takes an extra optional
+ keyword argument "dy". Setting it allows computation of vector jacobian
+ products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -557,7 +559,7 @@ def _aggregate_grads(gradients):
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
- return math_ops.add_n(gradients)
+ return gen_math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
@@ -592,7 +594,9 @@ def _num_elements(grad):
def _fast_fill(value, shape, dtype):
- return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
+ return array_ops.fill(
+ constant_op.constant(shape, dtype=dtypes.int32),
+ constant_op.constant(value, dtype=dtype))
def _zeros(shape, dtype):
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index e2b1890c2f..a2e8422671 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -350,6 +350,21 @@ class MicroBenchmarks(test.Benchmark):
func = lambda: f(m, m, transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
+ def _benchmark_defun_matmul_forward_backward(self,
+ m,
+ transpose_b,
+ num_iters,
+ execution_mode=None):
+ f = function.defun(math_ops.matmul)
+
+ def func():
+ with backprop.GradientTape() as gt:
+ gt.watch(m)
+ y = f(m, m, transpose_b)
+ _ = gt.gradient(y, m)
+
+ self._run(func, num_iters, execution_mode=execution_mode)
+
def _benchmark_read_variable(self, m, num_iters):
self._run(m.value, num_iters)
@@ -421,6 +436,21 @@ class MicroBenchmarks(test.Benchmark):
num_iters=self._num_iters_2_by_2,
execution_mode=context.ASYNC)
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m,
+ transpose_b=False,
+ num_iters=self._num_iters_2_by_2,
+ execution_mode=context.ASYNC)
+
def benchmark_tf_matmul_2_by_2_GPU(self):
if not context.num_gpus():
return
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5afba466bc..3f8dac0bd4 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -26,7 +26,6 @@ import threading
import numpy as np
import six
-from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -102,7 +101,7 @@ class CapturingGraph(ops.Graph):
The entries are in the order they were captured.
"""
- def __init__(self):
+ def __init__(self, graph=None):
super(CapturingGraph, self).__init__()
self.captures = collections.OrderedDict()
@@ -185,7 +184,6 @@ class FuncGraph(CapturingGraph):
Attributes:
name: The name of the function.
-
inputs: Placeholder tensors representing the inputs to this function. The
tensors are in this FuncGraph. This represents "regular" inputs as well as
captured inputs (i.e. the values of self.captures), with the regular
@@ -207,7 +205,7 @@ class FuncGraph(CapturingGraph):
graph: if specified, this FuncGraph will inherit its graph key,
collections, and seed from `graph`.
"""
- super(FuncGraph, self).__init__()
+ super(FuncGraph, self).__init__(graph=graph)
self.name = name
self.inputs = []
@@ -233,8 +231,12 @@ class FuncGraph(CapturingGraph):
if context.executing_eagerly():
self.seed = context.global_seed()
+ self._xla_compile = (context.context().device_spec.device_type == "TPU")
else:
self.seed = graph.seed
+ self._xla_compile = getattr(graph, "_xla_compile", False)
+ else:
+ self._xla_compile = False
def capture(self, tensor, name=None):
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
@@ -267,9 +269,6 @@ def _register(fn):
context.context().add_function(fn)
-_xla_compile_attr = "_XlaCompile"
-
-
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
@@ -282,18 +281,20 @@ class _EagerDefinedFunction(object):
class may be provided as the value of these `func` attributes.
"""
- def __init__(self, name, graph, operations, inputs, outputs, attrs):
+ def __init__(self, name, graph, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
name: str, the name for the created function.
graph: Graph, the graph containing the operations in the function
- operations: list of Operation; the subset of operations in the graph
- which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
attrs: dict mapping names of attributes to their AttrValue values
"""
+ operations = [
+ op for op in graph.get_operations()
+ if op not in set(arg.op for arg in inputs)
+ ]
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
@@ -311,7 +312,6 @@ class _EagerDefinedFunction(object):
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
- self._xla_compile = _xla_compile_attr in attrs
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
@@ -327,6 +327,7 @@ class _EagerDefinedFunction(object):
self.signature = function_def.signature
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
+ self._output_shapes = [o.shape for o in outputs]
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -347,7 +348,7 @@ class _EagerDefinedFunction(object):
def stateful_ops(self):
return self._stateful_ops
- def call(self, ctx, args, output_shapes):
+ def call(self, ctx, args):
"""Calls this function with `args` as inputs.
Function execution respects device annotations only if the function won't
@@ -356,8 +357,6 @@ class _EagerDefinedFunction(object):
Args:
ctx: a Context object
args: a list of arguments to supply this function with.
- output_shapes: shapes to which outputs should be set; ignored when
- executing eagerly.
Returns:
The outputs of the function call.
@@ -365,10 +364,7 @@ class _EagerDefinedFunction(object):
executing_eagerly = ctx.executing_eagerly()
- xla_compile = self._xla_compile or (executing_eagerly and
- ctx.device_spec.device_type == "TPU")
-
- if xla_compile:
+ if self._graph._xla_compile: # pylint: disable=protected-access
# XLA compilation relies upon a custom kernel creator to run functions.
signature = self.signature
if executing_eagerly:
@@ -406,7 +402,7 @@ class _EagerDefinedFunction(object):
if executing_eagerly:
return outputs
else:
- for i, shape in enumerate(output_shapes):
+ for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
return outputs
@@ -427,153 +423,106 @@ def _flatten(sequence):
return outputs
-# TODO(akshayka): Perhaps rename to something more appropriate.
-class GraphModeFunction(object):
+class GraphCallable(object):
"""Callable object encapsulating a function definition and its gradient.
- `GraphModeFunction` is a callable that encapsulates a function definition and
+ `GraphCallable` is a callable that encapsulates a function definition and
is differentiable under `tf.GradientTape` objects.
"""
- def __init__(self,
- name,
- input_placeholders,
- extra_inputs,
- graph,
- operations,
- outputs,
- python_func_outputs,
- output_shapes,
- variables=None,
- attrs=None):
- """Initialize a GraphModeFunction.
+ def __init__(self, func_graph, attrs=None):
+ """Initialize a GraphCallable.
Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- python_func_outputs: a possibly nested python object which will be
- returned by this function. The Tensors in this structure will be
- replaced by their corresponding values in outputs. Note that this
- structure might contain Python `None`s.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function
- execution.
+ func_graph: An instance of FuncGraph: the function body to wrap.
attrs: (optional) dict mapping names of attributes to their AttrValue
values. Attributes in `attrs` will be included in this function's
definition.
+
+ Raises:
+ ValueError: If number of input_placeholders is not equal to the number
+ of function inputs.
"""
+ self._func_graph = func_graph
+ self._captured_inputs = list(self._func_graph.captures.keys())
+ self._num_outputs = len(self._func_graph.outputs)
+ self._output_shapes = tuple(
+ output.shape for output in self._func_graph.outputs)
self._attrs = attrs or {}
- defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs, self._attrs)
- if len(input_placeholders) != len(defined_function.signature.input_arg):
- raise ValueError("Internal error: invalid lengths. %s %s" % (
- len(input_placeholders), len(defined_function.signature.input_arg)))
- self._input_placeholders = input_placeholders
- self._extra_inputs = list(extra_inputs)
- self._graph = graph
- self._backward_function = None
- self._func_name = name
- self._function_def = defined_function
- self._num_outputs = len(defined_function.signature.output_arg)
- self._python_func_outputs = python_func_outputs
- self._python_returns = [python_func_outputs] if isinstance(
- python_func_outputs,
- (ops.Tensor, type(None))) else _flatten(python_func_outputs)
- self._output_shapes = output_shapes
- self._variables = variables if variables is not None else []
-
- # Find the variables that are components of something distributed and
- # put them into a {handle_tensor -> distributed variable object} map.
+
+ self._inference_function = _EagerDefinedFunction(
+ _inference_name(self._func_graph.name), self._func_graph,
+ self._func_graph.inputs, self._func_graph.outputs, self._attrs)
+ self._backward_graph_callable = None
+
+ # Map holding distributed variables, keyed by resource handle tensors.
self._distributed_variables = {}
strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._variables:
+ for variable in self._func_graph.variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
- # Only add to the dictionary when the variable is actually distributed,
- # i.e. more than one component or the component is different from the
- # variable itself. component_variables cannot be empty.
+ # Only update the dictionary when the variable is actually distributed.
if (len(component_variables) > 1 or component_variables[0] != variable):
for component_variable in component_variables:
self._distributed_variables[component_variable.handle] = variable
@property
+ def graph(self):
+ return self._func_graph
+
+ @property
def variables(self):
- return self._variables
+ return self._func_graph.variables
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
- filtered_outputs = [x for x in self._python_returns if x is not None]
- # TODO(skyewm): use FuncGraph
- backwards_graph = CapturingGraph()
- backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
- for collection in self._graph.collections:
- backwards_graph.get_collection_ref(
- collection)[:] = self._graph.get_collection(collection)
- backwards_graph.seed = self._graph.seed
+ backwards_graph = FuncGraph(
+ _backward_name(self._func_graph.name), self._func_graph)
with backwards_graph.as_default():
- self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access
- filtered_outputs,
- self._input_placeholders,
- grad_ys=self._out_grad_placeholders,
- src_graph=self._graph)
-
- backward_outputs = tuple(
- grad for grad in _flatten(in_gradients) if grad is not None)
- output_shapes = tuple(grad.shape for grad in backward_outputs)
-
- extra_inputs = backwards_graph.captures.keys()
- extra_placeholders = backwards_graph.captures.values()
-
- forward_name = _forward_name(self._func_name)
- # Note: we cannot have placeholder ops in the graph or the TPU compilation
- # pass fails.
- placeholder_ops = set([y.op for y in self._input_placeholders])
- function_ops = [x for x in self._graph.get_operations()
- if x not in placeholder_ops]
- self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, function_ops,
- self._input_placeholders, filtered_outputs + list(extra_inputs),
+ gradients_wrt_outputs = [
+ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
+ ]
+ gradients_wrt_inputs = gradients_impl._GradientsHelper( # pylint: disable=protected-access
+ self._func_graph.outputs,
+ self._func_graph.inputs,
+ grad_ys=gradients_wrt_outputs,
+ src_graph=self._func_graph)
+
+ self._forward_function = _EagerDefinedFunction(
+ _forward_name(
+ self._func_graph.name), self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + list(backwards_graph.captures.keys()),
self._attrs)
- all_inputs = self._out_grad_placeholders + list(extra_placeholders)
- # Excluding input ops from the body as we do not intend to execute these
- # operations when the function is executed.
- all_ignored_ops = frozenset(x.op for x in all_inputs)
- # Enforce a deterministic order of operations in the generated graph. This
- # means rerunning the function-defining code will always define the same
- # function, which is useful if we serialize this etc.
- function_def_ops = tuple(x
- for x in sorted(backwards_graph.get_operations(),
- key=lambda x: x.name)
- if x not in all_ignored_ops)
- bname = _backward_name(self._func_name)
- self._backward_function = GraphModeFunction(
- bname, all_inputs, [], backwards_graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+
+ # The ordering of `backwards_graph.inputs` is important: inputs of
+ # `self._backward_graph_callable` correspond to outputs of
+ # `self._forward_function`.
+ backwards_graph.inputs = gradients_wrt_outputs + list(
+ backwards_graph.captures.values())
+ # Clear captures, since we pass them in as inputs.
+ backwards_graph.captures = {}
+ backwards_graph.outputs.extend(
+ grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
+ backwards_graph.structured_outputs = gradients_wrt_inputs
+ self._backward_graph_callable = GraphCallable(
+ backwards_graph, attrs=self._attrs)
def _backprop_call(self, args):
- """Calls the wrapped function and records the result on a tape.
+ """Calls the forward function and records the result on a tape.
(Only records results on a tape if the function has outputs)
Args:
- args: All inputs to the function, including resolved extra inputs
+ args: All inputs to the function, including resolved captured inputs
+
Returns:
The call output.
"""
+ if self._backward_graph_callable is None:
+ self._construct_backprop_function()
+
ctx = context.context()
- outputs = self._forward_fdef.call(ctx, args, self._output_shapes)
+ outputs = self._forward_function.call(ctx, args)
if isinstance(outputs, ops.Operation) or outputs is None:
return outputs
@@ -584,14 +533,10 @@ class GraphModeFunction(object):
side_outputs = outputs[self._num_outputs:]
def backward_function(*args):
- return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
-
- tape.record_operation(
- self._forward_fdef.signature.name,
- real_outputs,
- args,
- backward_function)
+ return self._backward_graph_callable(*(list(args) + side_outputs)) # pylint: disable=not-callable
+ tape.record_operation(self._forward_function.signature.name, real_outputs,
+ args, backward_function)
return self._build_call_outputs(real_outputs)
@property
@@ -599,7 +544,7 @@ class GraphModeFunction(object):
"""The function's output shapes."""
# TODO(ebrevdo): Should we only keep the output shapes associated
# with len(self._python_returns) outputs?
- outputs_list = nest.flatten(self._python_func_outputs)
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -613,23 +558,25 @@ class GraphModeFunction(object):
else:
outputs_list[i] = self._output_shapes[j]
j += 1
- return nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ return nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
@property
def output_dtypes(self):
- return nest.map_structure(
- lambda x: x.dtype if x is not None else None, self._python_func_outputs)
+ return nest.map_structure(lambda x: x.dtype if x is not None else None,
+ self._func_graph.structured_outputs)
@property
def captured_inputs(self):
- return self._extra_inputs
+ # TODO(akshayka): Should this return `_resolve_captured_inputs()`?
+ return self._captured_inputs
@property
def name(self):
"""Returns the name of the function in Eager-compatible format."""
- return self._function_def.name.encode("utf-8")
+ return self._inference_function.name.encode("utf-8")
- def _resolve_extra_inputs(self):
+ def _resolve_captured_inputs(self):
"""Resolve captured distributed variables to their current values.
Some inputs can be distributed variables. Such variables yield a different
@@ -637,43 +584,39 @@ class GraphModeFunction(object):
execution.
Returns:
- a list of resolved extra input tensors.
+ a list of resolved captured input tensors.
"""
if self._distributed_variables:
- # Loop over each extra_inputs and check if it corresponds to something
+ # Loop over each captured input and check if it corresponds to something
# distributed. If so, get its _distributed_container and fetch the
# component appropriate for the current execution context.
- resolved_extra_inputs = self._extra_inputs[:]
- for i, extra_input in enumerate(self._extra_inputs):
- distributed_var = self._distributed_variables.get(extra_input, None)
+ resolved_captured_inputs = self._captured_inputs[:]
+ for i, captured_input in enumerate(self._captured_inputs):
+ distributed_var = self._distributed_variables.get(captured_input, None)
if distributed_var is not None:
# distributed variables override __getattr__ and substitute the
# right component variable. In here, `distributed_var.handle`
# actually does the equivalent of
# distributed_var.get_current_component_var().handle.
- resolved_extra_inputs[i] = distributed_var.handle
- return resolved_extra_inputs
-
- return self._extra_inputs
+ resolved_captured_inputs[i] = distributed_var.handle
+ return resolved_captured_inputs
+ return self._captured_inputs
def __call__(self, *args):
"""Executes the passed function in eager mode."""
- for v in self._variables:
+ for v in self._func_graph.variables:
if v.trainable:
tape.watch_variable(v)
- resolved_extra_inputs = self._resolve_extra_inputs()
-
+ captures = self._resolve_captured_inputs()
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
- args = tensor_inputs + resolved_extra_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(
- resolved_extra_inputs):
- if self._backward_function is None:
- self._construct_backprop_function()
+ args = tensor_inputs + captures
+
+ if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
ctx = context.context()
- outputs = self._function_def.call(ctx, args, self._output_shapes)
+ outputs = self._inference_function.call(ctx, args)
return self._build_call_outputs(outputs)
def _build_call_outputs(self, result):
@@ -684,12 +627,12 @@ class GraphModeFunction(object):
Returns:
The actual call output.
"""
- if self._python_func_outputs is None:
+ if self._func_graph.structured_outputs is None:
return result
# Use `nest.flatten` instead of `_flatten` in order to preserve any
- # IndexedSlices in `self._python_func_outputs`.
- outputs_list = nest.flatten(self._python_func_outputs)
+ # IndexedSlices in `self._func_graph.structured_outputs`.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -703,13 +646,13 @@ class GraphModeFunction(object):
j += 3
else:
outputs_list[i] = ops.IndexedSlices(
- values=result[j],
- indices=result[j + 1])
+ values=result[j], indices=result[j + 1])
j += 2
else:
outputs_list[i] = result[j]
j += 1
- ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
return ret
@@ -725,20 +668,18 @@ def _get_defun_inputs_from_signature(signature):
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
- graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor)
- else arg for arg in nest.flatten(args)
+ graph_placeholder(arg.dtype, arg.shape)
+ if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
-def _trace_and_define_function(name, python_func, compiled, args, kwds,
- signature=None):
- """Defines and returns graph-mode version of `python_func`.
+def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+ """Returns a `FuncGraph` generated from `python_func`.
Args:
name: an identifier for the function.
python_func: the Python function to trace.
- compiled: whether the graph function should be compiled through XLA.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
kwds: the keyword args with which the Python function should be called;
@@ -750,14 +691,13 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
inputs.
Returns:
- A GraphModeFunction.
+ A FuncGraph.
Raises:
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
-
+ func_graph = FuncGraph(name, graph=ops.get_default_graph())
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
@@ -771,8 +711,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
func_graph.inputs.extend(
x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
- if isinstance(x, ops.Tensor)
- )
+ if isinstance(x, ops.Tensor))
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
@@ -797,6 +736,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
this_tape = tape.push_new_tape()
try:
func_outputs = python_func(*func_args, **func_kwds)
+ # invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
def check_mutation(n1, n2):
@@ -816,53 +756,34 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
check_mutation(func_args_before, func_args)
check_mutation(func_kwds_before, func_kwds)
-
finally:
tape.pop_tape(this_tape)
+
func_graph.structured_outputs = func_outputs
+ # Returning a closed-over tensor does not trigger convert_to_tensor.
+ func_graph.outputs.extend(
+ func_graph.capture(x)
+ for x in _flatten(func_graph.structured_outputs)
+ if x is not None)
+
+ # Some captured variables might be components of DistributedValues.
+ # Instead of storing non-distributed component variables, we
+ # store their distributed containers so we can retrieve the correct
+ # component variables at call-time.
variables = list(this_tape.watched_variables())
-
- # Some variables captured by the tape can come from a DistributedValue.
- # At call time, DistributedValue can return another variable (e.g. if
- # the function is run on a different device). Thus, instead of storing
- # the specific captured variable, we replace it with its distributed
- # container.
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
-
func_graph.variables = variables
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- func_graph.outputs.extend(
- func_graph.capture(x) for x in _flatten(func_graph.structured_outputs)
- if x is not None
- )
-
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in func_graph.outputs)
-
- all_ignored_ops = frozenset(x.op for x in func_graph.inputs)
- operations = tuple(x for x in func_graph.get_operations()
- if x not in all_ignored_ops)
- # Register any other functions defined in the graph
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
+ # Register any other functions defined in the graph.
if context.executing_eagerly():
for f in func_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
- attrs = {}
- if compiled:
- attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
-
- return GraphModeFunction(
- func_graph.name, func_graph.inputs, func_graph.captures.keys(),
- func_graph, operations, func_graph.outputs, func_graph.structured_outputs,
- output_shapes, func_graph.variables, attrs)
+ return func_graph
_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
@@ -926,8 +847,7 @@ class _PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None,
- compiled=False):
+ input_signature=None):
"""Initializes a polymorphic function.
Args:
@@ -936,7 +856,6 @@ class _PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
- compiled: if True, the framework will attempt to compile func with XLA.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -955,7 +874,6 @@ class _PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwds_to_include = {}
self._name = name
- self._compiled = compiled
self._arguments_to_functions = {}
self._variables = []
@@ -1085,8 +1003,9 @@ class _PolymorphicFunction(object):
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
- tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
- for tensor in flat_inputs]
+ tensor_specs = [
+ tensor_spec.TensorSpec.from_tensor(tensor) for tensor in flat_inputs
+ ]
if any(not spec.is_compatible_with(other)
for spec, other in zip(self._flat_input_signature, tensor_specs)):
raise ValueError("Python inputs incompatible with input_signature: "
@@ -1120,9 +1039,9 @@ class _PolymorphicFunction(object):
"must be hashable.")
if graph_function is None:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds,
- self._input_signature)
+ graph_function = GraphCallable(
+ _func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
self._arguments_to_functions[cache_key] = graph_function
@@ -1143,10 +1062,7 @@ class _PolymorphicFunction(object):
return self._variables
-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1437,9 +1353,10 @@ def defun(func=None, input_signature=None, compiled=False):
func: function to be compiled. If `func` is None, returns a
decorator that can be invoked with a single argument - `func`. The
end result is equivalent to providing all the arguments up front.
- In other words, defun(compiled=True)(func) is equivalent to
- defun(func, compiled=True). The former allows the following use case:
- @tf.contrib.eager.defun(compiled=True)
+ In other words, defun(input_signature=...)(func) is equivalent to
+ defun(func, input_signature=...). The former allows
+ the following use case:
+ @tf.contrib.eager.defun(input_signature=...)
def foo(...):
...
@@ -1450,11 +1367,6 @@ def defun(func=None, input_signature=None, compiled=False):
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
- compiled: If True, an attempt to compile `func` with XLA will be made.
- If it fails, function will be run normally. Experimental. Currently
- supported only for execution on TPUs. For the vast majority of users,
- this argument should be False.
-
Returns:
If `func` is not None, returns a callable that will execute the compiled
function (and return zero or more `tf.Tensor` objects).
@@ -1470,7 +1382,7 @@ def defun(func=None, input_signature=None, compiled=False):
return tf_decorator.make_decorator(
function,
_PolymorphicFunction(
- function, name, input_signature=input_signature, compiled=compiled))
+ function, name, input_signature=input_signature))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1528,7 +1440,8 @@ def make_defun_op(func, *args, **kwds):
and which can be called directly the way a `@defun` wrapped function
can.
"""
- return _trace_and_define_function(func.__name__, func, False, args, kwds)
+ return GraphCallable(
+ _func_graph_from_py_func(func.__name__, func, args, kwds))
class AutomaticControlDependencies(object):
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
deleted file mode 100644
index 7105d2e399..0000000000
--- a/tensorflow/python/eager/graph_callable.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Decorator that produces a callable object that executes a TensorFlow graph.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.eager import context
-from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _default_initializer(name, shape, dtype):
- """The default initializer for variables."""
- # pylint: disable=protected-access
- store = variable_scope._get_default_variable_store()
- initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
- # pylint: enable=protected-access
- return initializer[0]
-
-
-class _CapturedVariable(object):
- """Variable captured by graph_callable.
-
- Internal to the implementation of graph_callable. Created only by
- _VariableCapturingScope and used only to read the variable values when calling
- the function after the variables are initialized.
- """
-
- def __init__(self, name, initializer, shape, dtype, trainable):
- self.name = name
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- initial_value = lambda: initializer(shape, dtype=dtype)
-
- with context.eager_mode():
- self.variable = resource_variable_ops.ResourceVariable(
- initial_value=initial_value, name=name, dtype=dtype,
- trainable=trainable)
- self.shape = shape
- self.dtype = dtype
- self.placeholder = None
- self.trainable = trainable
-
- def read(self, want_gradients=True):
- if want_gradients and self.trainable:
- v = tape.watch_variable(self.variable)
- else:
- v = self.variable
- return v.read_value()
-
-
-class _VariableCapturingScope(object):
- """Variable-scope-like object which captures tf.get_variable calls.
-
- This is responsible for the main difference between the initialization version
- of a function object and the calling version of a function object.
-
- capturing_scope replaces calls to tf.get_variable with placeholder tensors to
- be fed the variable's current value. TODO(apassos): these placeholders should
- instead be objects implementing a similar API to tf.Variable, for full
- compatibility.
-
- initializing_scope replaces calls to tf.get_variable with creation of
- variables and initialization of their values. This allows eventual support of
- initialized_value and friends.
-
- TODO(apassos): once the eager mode layers API is implemented support eager
- func-to-object as well.
- """
-
- def __init__(self):
- self.variables = {}
- self.tf_variables = {}
-
- @contextlib.contextmanager
- def capturing_scope(self):
- """Context manager to capture variable creations.
-
- Replaces variable accesses with placeholders.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape
- del aggregation, synchronization
- assert name in self.variables
- v = self.variables[name]
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
- @contextlib.contextmanager
- def initializing_scope(self):
- """Context manager to capture variable creations.
-
- Forcibly initializes all created variables.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape, aggregation, synchronization
- if name in self.tf_variables:
- if reuse:
- return self.tf_variables[name].initialized_value()
- else:
- raise ValueError("Specified reuse=%s but tried to reuse variables."
- % reuse)
- # TODO(apassos): ensure this is on the same device as above
- v = _CapturedVariable(name, initializer, shape, dtype, trainable)
- self.variables[name] = v
-
- graph_mode_resource = v.variable.handle
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- resource_variable_ops.shape_safe_assign_variable_handle(
- graph_mode_resource, v.variable.shape, initializer(shape, dtype))
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
-
-class _InitializingFunctionObject(object):
- """Responsible for deciding which version of func-to-object to call.
-
- call_fn is the version which calls the function with the current values of the
- variables and init_fn is the version which calls the function to initialize
- all variables.
-
- TODO(apassos): figure out a way to support initializing only _some_
- variables. This requires a way to pull out a variable's initialization code
- from the graph, which might not be possible in general.
- """
-
- def __init__(self, call_fn, init_fn, shape_and_dtypes):
- self._init_fn = init_fn
- self._call_fn = call_fn
- self.shape_and_dtypes = shape_and_dtypes
- self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
- nest.flatten(self.shape_and_dtypes)]
-
- @property
- def variables(self):
- return self._call_fn.variables
-
- def __call__(self, *args):
- nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
- if not all([
- shape.is_compatible_with(arg.shape)
- for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
- ]):
- raise ValueError(
- "Declared shapes do not match argument shapes: Expected %s, found %s."
- % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))
-
- initialized = [resource_variable_ops.var_is_initialized_op(
- v.handle).numpy() for v in self._call_fn.variables]
- if all(x for x in initialized):
- for v in self._call_fn.variables:
- if v.trainable:
- tape.watch_variable(v)
- return self._call_fn(*args)
- elif all(not x for x in initialized):
- return self._init_fn(*args)
- else:
- raise ValueError("Some, but not all, variables are initialized.")
-
-
-def _get_graph_callable_inputs(shape_and_dtypes):
- """Maps specified shape_and_dtypes to graph inputs."""
- ret = []
- for x in shape_and_dtypes:
- if isinstance(x, ShapeAndDtype):
- ret.append(array_ops.placeholder(x.dtype, x.shape))
- elif isinstance(x, (tuple, list)):
- ret.append(_get_graph_callable_inputs(x))
- else:
- raise errors.InvalidArgumentError(
- None, None, "Expected the argument to @graph_callable to be a "
- "(possibly nested) list or tuple of ShapeAndDtype objects, "
- "but got an object of type: %s" % type(x))
-
- return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
-
-
-def _graph_callable_internal(func, shape_and_dtypes):
- """Defines and returns a template version of func.
-
- Under the hood we make two function objects, each wrapping a different version
- of the graph-mode code. One version immediately runs variable initialization
- before making the variable's Tensors available for use, while the other
- version replaces the Variables with placeholders which become function
- arguments and get the current variable's value.
-
- Limitations in (2) and (4) are because this does not implement a graph-mode
- Variable class which has a convert_to_tensor(as_ref=True) method and a
- initialized_value method. This is fixable.
-
- Args:
- func: The tfe Python function to compile.
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
-
- Raises:
- ValueError: If any one of func's outputs is not a Tensor.
-
- Returns:
- Callable graph object.
- """
- container = tf_ops.get_default_graph()._container # pylint: disable=protected-access
- graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- # This graph will store both the initialization and the call version of the
- # wrapped function. It will later be used by the backprop code to build the
- # backprop graph, if necessary.
- tmp_graph = function.CapturingGraph()
- # Inherit the graph key from the original graph to ensure optimizers don't
- # misbehave.
- tmp_graph._container = container # pylint: disable=protected-access
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- with tmp_graph.as_default():
- # Placeholders for the non-variable inputs.
- func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
- func_num_args = len(tf_inspect.getfullargspec(func).args)
- if len(func_inputs) != func_num_args:
- raise TypeError("The number of arguments accepted by the decorated "
- "function `%s` (%d) must match the number of "
- "ShapeAndDtype objects passed to the graph_callable() "
- "decorator (%d)." %
- (func.__name__, func_num_args, len(func_inputs)))
-
- # First call the function to generate a graph which can initialize all
- # variables. As a side-effect this will populate the variable capturing
- # scope's view of which variables exist.
- variable_captures = _VariableCapturingScope()
- with variable_captures.initializing_scope(
- ), function.AutomaticControlDependencies() as a:
- func_outputs = func(*func_inputs)
- outputs_list = nest.flatten(func_outputs)
- for i, x in enumerate(outputs_list):
- if x is not None:
- outputs_list[i] = a.mark_as_return(x)
- if len(outputs_list) == 1 and outputs_list[0] is None:
- outputs_list = []
- output_shapes = [x.shape for x in outputs_list]
- if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
- raise ValueError("Found non-tensor output in %s" % str(outputs_list))
- initializing_operations = tmp_graph.get_operations()
-
- # Call the function again, now replacing usages of variables with
- # placeholders. This assumes the variable capturing scope created above
- # knows about all variables.
- tmp_graph.clear_resource_control_flow_state()
- with variable_captures.capturing_scope(
- ), function.AutomaticControlDependencies() as a:
- captured_outputs = func(*func_inputs)
- captured_outlist = nest.flatten(captured_outputs)
- for i, x in enumerate(captured_outlist):
- if x is not None:
- captured_outlist[i] = a.mark_as_return(x)
- capturing_operations = tmp_graph.get_operations()[
- len(initializing_operations):]
-
- sorted_variables = sorted(variable_captures.variables.values(),
- key=lambda x: x.name)
-
- extra_inputs = tmp_graph.captures.keys()
- extra_placeholders = tmp_graph.captures.values()
-
- flat_inputs = [x for x in nest.flatten(func_inputs)
- if isinstance(x, tf_ops.Tensor)]
- placeholder_inputs = flat_inputs+ list(extra_placeholders)
-
- func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
- initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
- # Also, what about the gradient registry of these functions? Those need to be
- # addressed as well.
- for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register(f._c_func.func) # pylint: disable=protected-access
- initializer_function = function.GraphModeFunction(
- initialization_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- initializing_operations,
- func_def_outputs,
- func_outputs,
- output_shapes)
-
- capture_func_def_outputs = [
- x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
- captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- captured_function = function.GraphModeFunction(
- captured_function_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- capturing_operations,
- capture_func_def_outputs,
- captured_outputs,
- output_shapes,
- variables=[x.variable for x in sorted_variables])
-
- return _InitializingFunctionObject(captured_function, initializer_function,
- shape_and_dtypes)
-
-
-class ShapeAndDtype(object):
- """Data type that packages together shape and type information.
-
- Used for arguments to graph callables. See graph_callable() for an example.
- """
-
- def __init__(self, shape, dtype):
- self.shape = shape
- self.dtype = dtype
-
-
-def graph_callable(shape_and_dtypes):
- """Decorator that produces a callable that executes a TensorFlow graph.
-
- When applied on a function that constructs a TensorFlow graph, this decorator
- produces a callable object that:
-
- 1. Executes the graph when invoked. The first call will initialize any
- variables defined in the graph.
-
- 2. Provides a .variables() method to return the list of TensorFlow variables
- defined in the graph.
-
- Note that the wrapped function is not allowed to change the values of the
- variables, just use them.
-
- The return value of the wrapped function must be one of the following:
- (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
-
- Example:
-
- ```python
- @tfe.graph_callable([tfe.ShapeAndDtype(shape(), dtype=dtypes.float32)])
- def foo(x):
- v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
- return v + x
-
- ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0.
-
- foo.variables[0].assign(7.0) # Modify the value of variable `v`.
- ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0.
- ```
- Args:
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
- that specifies shape and type information for each of the callable's
- arguments. The length of this list must be equal to the number of
- arguments accepted by the wrapped function.
-
- Returns:
- A callable graph object.
- """
- # TODO(alive,apassos): support initialized_value and friends from tf.Variable.
- assert context.executing_eagerly(), (
- "graph_callable can only be used when Eager execution is enabled.")
- def decorator(func):
- return tf_decorator.make_decorator(func,
- _graph_callable_internal(
- func, shape_and_dtypes))
-
- return decorator
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
deleted file mode 100644
index b9e6ca2a93..0000000000
--- a/tensorflow/python/eager/graph_callable_test.py
+++ /dev/null
@@ -1,249 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import graph_callable
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-
-
-class GraphCallableTest(test.TestCase):
-
- def testBasic(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testFunctionWithoutReturnValue(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x)
-
- my_function(constant_op.constant(4, dtype=dtypes.float32))
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testFunctionWithoutReturnValueAndArgs(self):
-
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(4)
-
- my_function()
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testVariableAPI(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v.read_value() + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testTensorShape(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- _ = x.get_shape()
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
- self.assertEqual(v.shape[0], x.shape[0])
- return v + x
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testUpdatesAreOrdered(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x + 1)
- v.assign(v * x)
- return v.read_value()
-
- self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0)
-
- def testEmptyInitializer(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable("v", shape=[1])
- return x + 0 * v
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testMismatchingNumArgs(self):
- # pylint: disable=anomalous-backslash-in-string
- with self.assertRaisesRegexp(TypeError,
- "The number of arguments accepted by the "
- "decorated function `my_function` \(2\) must "
- "match the number of ShapeAndDtype objects "
- "passed to the graph_callable\(\) decorator "
- "\(1\)."):
- @graph_callable.graph_callable([
- graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x, y): # pylint: disable=unused-variable
- return x + y
- # pylint: enable=anomalous-backslash-in-string
-
- def testPureFunction(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def f(x):
- return math_ops.add(x, constant_op.constant(3))
-
- self.assertAllEqual(5, f(constant_op.constant(2)))
-
- def testNestedFunction(self):
- # TensorFlow function (which is what would be used in TensorFlow graph
- # construction).
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- # A graph_callable that will invoke the TensorFlow function.
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- self.assertAllEqual(3, add_one(constant_op.constant(2)))
-
- # TODO(ashankar): Make this work.
- # The problem is that the two graph_callables (for add_one and add_two)
- # are both trying to register the FunctionDef corresponding to "add".
- def DISABLED_testRepeatedUseOfSubFunction(self):
-
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_two(x):
- return add(x, 2)
-
- two = constant_op.constant(2)
- self.assertAllEqual(3, add_one(two))
- self.assertAllEqual(4, add_two(two))
-
- def testNestedSequenceInputs(self):
- sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
- @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
- def my_op(inputs):
- a, b, c = inputs
- e, f = b
- v = variable_scope.get_variable(
- "my_v", initializer=init_ops.zeros_initializer(), shape=())
- return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v
-
- inputs = [constant_op.constant(1.),
- [constant_op.constant(2.), constant_op.constant(3.)],
- constant_op.constant(4.)]
- ret = my_op(inputs)
- self.assertEqual(len(ret), 2.)
- self.assertAllEqual(ret[1], 10.)
-
- my_op.variables[0].assign(1.)
- ret = my_op(inputs)
- self.assertAllEqual(ret[1], 11.)
-
- def testVariableShapeIsTensorShape(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)
-
- my_function()
-
- def testIncorrectlyShapedInputs(self):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(3), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- with self.assertRaises(ValueError):
- my_function([1, 2])
-
- self.assertTrue(([1, 2, 3] == my_function(
- constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy()).all())
-
- def testGradients(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.constant_initializer(3.), shape=())
- return v * v
-
- grad_fn = backprop.implicit_grad(my_function)
- grads_and_vars = list(zip(*grad_fn()))
- self.assertAllEqual(6., grads_and_vars[0][0])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 15d2ccf9d2..c12bf89f8f 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -800,9 +800,6 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
EagerTensorType = &_EagerTensorType;
Py_INCREF(EagerTensorType);
#endif
- // We disable instance based attribute lookup. Its not clear if these
- // dictionaries are correctly initialized in the first place.
- EagerTensorType->tp_dictoffset = 0;
return reinterpret_cast<PyObject*>(EagerTensorType);
}
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 16928ca4b7..ef7c217190 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -404,18 +404,21 @@ class _EnsembleGrower(object):
training_ops.append(grow_op)
"""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
"""Initializes a grower object.
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ feature_ids_list: a list of lists of feature ids for each bucket size.
+
Raises:
ValueError: when pruning mode is invalid or pruning is used and no tree
complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
+ self._feature_ids_list = feature_ids_list
# pylint: disable=protected-access
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
@@ -440,14 +443,12 @@ class _EnsembleGrower(object):
"""
@abc.abstractmethod
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
"""Grows a tree, if ready, based on provided statistics.
Args:
stats_summaries_list: List of stats summary tensors, representing sums of
gradients and hessians for each feature bucket.
- feature_ids_list: a list of lists of feature ids for each bucket size.
last_layer_nodes_range: A tensor representing ids of the nodes in the
current layer, to be split.
@@ -455,6 +456,10 @@ class _EnsembleGrower(object):
An op for growing a tree.
"""
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.no_op()
+
# ============= Helper methods ===========
def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +473,7 @@ class _EnsembleGrower(object):
return center_bias_var.assign(continue_centering)
def _grow_tree_from_stats_summaries(self, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range):
+ last_layer_nodes_range):
"""Updates ensemble based on the best gains from stats summaries."""
node_ids_per_feature = []
gains_list = []
@@ -476,11 +481,11 @@ class _EnsembleGrower(object):
left_node_contribs_list = []
right_node_contribs_list = []
all_feature_ids = []
- assert len(stats_summaries_list) == len(feature_ids_list)
+ assert len(stats_summaries_list) == len(self._feature_ids_list)
max_splits = _get_max_splits(self._tree_hparams)
- for i, feature_ids in enumerate(feature_ids_list):
+ for i, feature_ids in enumerate(self._feature_ids_list):
(numeric_node_ids_per_feature, numeric_gains_list,
numeric_thresholds_list, numeric_left_node_contribs_list,
numeric_right_node_contribs_list) = (
@@ -516,12 +521,13 @@ class _EnsembleGrower(object):
class _InMemoryEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An in-memory ensemble grower."""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
super(_InMemoryEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
def center_bias(self, center_bias_var, gradients, hessians):
# For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +537,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
# For in memory, we already have full data in one batch, so we can grow the
# tree immediately.
return self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ stats_summaries_list, last_layer_nodes_range)
class _AccumulatorEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An accumulator based ensemble grower."""
def __init__(self, tree_ensemble, tree_hparams, stamp_token,
- n_batches_per_layer, bucket_size_list, is_chief):
+ n_batches_per_layer, bucket_size_list, is_chief, center_bias,
+ feature_ids_list):
super(_AccumulatorEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
self._stamp_token = stamp_token
self._n_batches_per_layer = n_batches_per_layer
self._bucket_size_list = bucket_size_list
self._is_chief = is_chief
+ self._growing_accumulators = []
+ self._chief_init_ops = []
+ max_splits = _get_max_splits(self._tree_hparams)
+ for i, feature_ids in enumerate(self._feature_ids_list):
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ self._chief_init_ops.append(
+ accumulator.set_global_step(self._stamp_token))
+ self._growing_accumulators.append(accumulator)
+ self._center_bias = center_bias
+ if center_bias:
+ self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians means only.
+ # TODO(nponomareva): this will change for a multiclass
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+ self._chief_init_ops.append(
+ self._bias_accumulator.set_global_step(self._stamp_token))
def center_bias(self, center_bias_var, gradients, hessians):
# For not in memory situation, we need to accumulate enough of batches first
# before proceeding with centering bias.
# Create an accumulator.
+ if not self._center_bias:
+ raise RuntimeError('center_bias called but bias centering is disabled.')
bias_dependencies = []
- bias_accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians means only.
- # TODO(nponomareva): this will change for a multiclass
- shape=[2, 1],
- shared_name='bias_accumulator')
-
grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
- apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+ apply_grad = self._bias_accumulator.apply_grad(
+ grads_and_hess, self._stamp_token)
bias_dependencies.append(apply_grad)
# Center bias if enough batches were processed.
with ops.control_dependencies(bias_dependencies):
if not self._is_chief:
return control_flow_ops.no_op()
+ def _set_accumulators_stamp():
+ return control_flow_ops.group(
+ [acc.set_global_step(self._stamp_token + 1) for acc in
+ self._growing_accumulators])
def center_bias_from_accumulator():
- accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
- return self._center_bias_fn(center_bias_var,
- array_ops.expand_dims(accumulated[0], 0),
- array_ops.expand_dims(accumulated[1], 0))
+ accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
+ axis=0)
+ center_bias_op = self._center_bias_fn(
+ center_bias_var,
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+ with ops.control_dependencies([center_bias_op]):
+ return control_flow_ops.cond(center_bias_var,
+ control_flow_ops.no_op,
+ _set_accumulators_stamp)
center_bias_op = control_flow_ops.cond(
- math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
self._n_batches_per_layer),
center_bias_from_accumulator,
control_flow_ops.no_op,
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
- # For not in memory situation, we need to accumulate enough of batches first
- # before proceeding with building a tree layer.
- max_splits = _get_max_splits(self._tree_hparams)
-
- # Prepare accumulators.
- accumulators = []
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
dependencies = []
- for i, feature_ids in enumerate(feature_ids_list):
+ for i in range(len(self._feature_ids_list)):
stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
+ apply_grad = self._growing_accumulators[i].apply_grad(
array_ops.stack(stats_summaries, axis=0), self._stamp_token)
dependencies.append(apply_grad)
@@ -617,7 +638,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
return control_flow_ops.no_op()
min_accumulated = math_ops.reduce_min(
- array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+ array_ops.stack([acc.num_accumulated() for acc in
+ self._growing_accumulators]))
def grow_tree_from_accumulated_summaries_fn():
"""Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +647,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
stats_summaries_list = []
stats_summaries_list = [
array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
+ for accumulator in self._growing_accumulators
]
grow_op = self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ stats_summaries_list, last_layer_nodes_range
+ )
return grow_op
grow_model = control_flow_ops.cond(
@@ -638,6 +661,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
name='wait_until_n_batches_accumulated')
return grow_model
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.group(self._chief_init_ops)
+
def _bt_model_fn(
features,
@@ -683,21 +710,7 @@ def _bt_model_fn(
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
- is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
- center_bias = tree_hparams.center_bias
-
- if train_in_memory:
- assert n_batches_per_layer == 1, (
- 'When train_in_memory is enabled, input_fn should return the entire '
- 'dataset as a single batch, and n_batches_per_layer should be set as '
- '1.')
- if (not config.is_chief or config.num_worker_replicas > 1 or
- config.num_ps_replicas > 0):
- raise ValueError('train_in_memory is supported only for '
- 'non-distributed training.')
- worker_device = control_flow_ops.no_op().device
- train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
@@ -724,6 +737,20 @@ def _bt_model_fn(
logits=logits)
# ============== Training graph ==============
+ center_bias = tree_hparams.center_bias
+ is_single_machine = (config.num_worker_replicas <= 1)
+
+ if train_in_memory:
+ assert n_batches_per_layer == 1, (
+ 'When train_in_memory is enabled, input_fn should return the entire '
+ 'dataset as a single batch, and n_batches_per_layer should be set as '
+ '1.')
+ if (not config.is_chief or config.num_worker_replicas > 1 or
+ config.num_ps_replicas > 0):
+ raise ValueError('train_in_memory is supported only for '
+ 'non-distributed training.')
+ worker_device = control_flow_ops.no_op().device
+ train_op = []
# Extract input features and set up cache for training.
training_state_cache = None
if train_in_memory:
@@ -742,22 +769,6 @@ def _bt_model_fn(
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
-
- # Variable that determines whether bias centering is needed.
- center_bias_var = variable_scope.variable(
- initial_value=center_bias, name='center_bias_needed', trainable=False)
- if is_single_machine:
- local_tree_ensemble = tree_ensemble
- ensemble_reload = control_flow_ops.no_op()
- else:
- # Have a local copy of ensemble for the distributed setting.
- with ops.device(worker_device):
- local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
- name=name + '_local', is_local=True)
- # TODO(soroush): Do partial updates if this becomes a bottleneck.
- ensemble_reload = local_tree_ensemble.deserialize(
- *tree_ensemble.serialize())
-
if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -770,21 +781,46 @@ def _bt_model_fn(
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))
+ if is_single_machine:
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ # Have a local copy of ensemble for the distributed setting.
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
- summary.scalar('ensemble/num_trees', num_trees)
- summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
- summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
tree_ensemble_handle=local_tree_ensemble.resource_handle,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
- logits = cached_logits + partial_logits
+ logits = cached_logits + partial_logits
+
+ if train_in_memory:
+ grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
+ feature_ids_list=feature_ids_list)
+ else:
+ grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+ stamp_token, n_batches_per_layer,
+ bucket_size_list, config.is_chief,
+ center_bias=center_bias,
+ feature_ids_list=feature_ids_list)
+
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+ # Variable that determines whether bias centering is needed.
+ center_bias_var = variable_scope.variable(
+ initial_value=center_bias, name='center_bias_needed', trainable=False,
+ use_resource=True)
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
@@ -823,24 +859,20 @@ def _bt_model_fn(
axis=0) for f in feature_ids
]
stats_summaries_list.append(summaries)
-
- if train_in_memory and is_single_machine:
- grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
+ if center_bias:
+ update_model = control_flow_ops.cond(
+ center_bias_var,
+ functools.partial(
+ grower.center_bias,
+ center_bias_var,
+ gradients,
+ hessians,
+ ),
+ functools.partial(grower.grow_tree, stats_summaries_list,
+ last_layer_nodes_range))
else:
- grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
- stamp_token, n_batches_per_layer,
- bucket_size_list, config.is_chief)
-
- update_model = control_flow_ops.cond(
- center_bias_var,
- functools.partial(
- grower.center_bias,
- center_bias_var,
- gradients,
- hessians,
- ),
- functools.partial(grower.grow_tree, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range))
+ update_model = grower.grow_tree(stats_summaries_list,
+ last_layer_nodes_range)
train_op.append(update_model)
with ops.control_dependencies([update_model]):
@@ -859,10 +891,22 @@ def _bt_model_fn(
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
(_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
- tree_hparams.n_trees, tree_hparams.max_depth),))
+ tree_hparams.n_trees, tree_hparams.max_depth),),
+ training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
+ list(estimator_spec.training_chief_hooks))
return estimator_spec
+class GrowerInitializationHook(session_run_hook.SessionRunHook):
+ """A SessionRunHook handles initialization of `_EnsembleGrower`."""
+
+ def __init__(self, init_op):
+ self._init_op = init_op
+
+ def after_create_session(self, session, coord):
+ session.run(self._init_op)
+
+
def _create_classification_head(n_classes,
weight_column=None,
label_vocabulary=None):
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index ec597e4686..08026a93c5 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
+ def testTrainTwiceAndEvaluateBinaryClassifier(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=5,
+ max_depth=10)
+
+ num_steps = 2
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ est.train(input_fn, steps=num_steps)
+
+ self._assert_checkpoint(
+ est.model_dir, global_step=num_steps * 2,
+ finalized_trees=0, attempted_layers=4)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+
def testInferBinaryClassifier(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/canned/prediction_keys.py b/tensorflow/python/estimator/canned/prediction_keys.py
index 16890ec09a..daa275b46b 100644
--- a/tensorflow/python/estimator/canned/prediction_keys.py
+++ b/tensorflow/python/estimator/canned/prediction_keys.py
@@ -32,3 +32,4 @@ class PredictionKeys(object):
LOGITS = 'logits'
PREDICTIONS = 'predictions'
PROBABILITIES = 'probabilities'
+ TOP_K = 'top_k'
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eab608813b..f7ee42c7f6 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -120,7 +120,8 @@ class Estimator(object):
warm_start_from=None):
"""Constructs an `Estimator` instance.
- See @{$estimators} for more information. To warm-start an `Estimator`:
+ See [estimators](https://tensorflow.org/guide/estimators) for more information.
+ To warm-start an `Estimator`:
```python
estimator = tf.estimator.DNNClassifier(
@@ -152,9 +153,9 @@ class Estimator(object):
* `params`: Optional `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tuning.
- * `config`: Optional configuration object. Will receive what is passed
- to Estimator in `config` parameter, or the default `config`.
- Allows updating things in your `model_fn` based on
+ * `config`: Optional `estimator.RunConfig` object. Will receive what
+ is passed to Estimator as its `config` parameter, or a default
+ value. Allows setting up things in your `model_fn` based on
configuration such as `num_ps_replicas`, or `model_dir`.
* Returns:
@@ -166,7 +167,7 @@ class Estimator(object):
path will be resolved. If `None`, the model_dir in `config` will be used
if set. If both are set, they must be same. If both are `None`, a
temporary directory will be used.
- config: Configuration object.
+ config: `estimator.RunConfig` configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
@@ -184,8 +185,8 @@ class Estimator(object):
"""
Estimator._assert_members_are_not_overridden(self)
- config = maybe_overwrite_model_dir_and_session_config(config, model_dir)
- self._config = config
+ self._config = maybe_overwrite_model_dir_and_session_config(config,
+ model_dir)
# The distribute field contains an instance of DistributionStrategy.
self._train_distribution = self._config.train_distribute
@@ -285,8 +286,10 @@ class Estimator(object):
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more information.
- The function should construct and return one of the following: * A
+ See [Premade
+ Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
`(features, labels)` with same constraints as below. * A tuple
`(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
@@ -321,6 +324,14 @@ class Estimator(object):
ValueError: If both `steps` and `max_steps` are not `None`.
ValueError: If either `steps` or `max_steps <= 0`.
"""
+ if self.config.task_type in (run_config.TaskType.EVALUATOR,
+ run_config.TaskType.PS):
+ raise ValueError(
+ 'Train has been called wrong configuration. Please use '
+ 'tf.estimator.train_and_evaluate which calls propper API according '
+ 'to given configuration. Current configuration: {}.'.format(
+ self.config))
+
with context.graph_mode():
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
@@ -394,7 +405,8 @@ class Estimator(object):
Args:
input_fn: A function that constructs the input data for evaluation. See
- @{$premade_estimators#create_input_functions} for more information. The
+ [Premade Estimators](https://tensorflow.org/guide/premade#create_input_functions}
+ for more information. The
function should construct and return one of the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
`(features, labels)` with same constraints as below. * A tuple
@@ -450,9 +462,7 @@ class Estimator(object):
output_dir=self.eval_dir(name))
with ops.Graph().as_default():
- # TODO(priyag): Support distributed eval on TPUs.
- if (self._eval_distribution
- and self._eval_distribution.__class__.__name__ != 'TPUStrategy'):
+ if self._eval_distribution:
with self._eval_distribution.scope():
return _evaluate()
else:
@@ -478,8 +488,9 @@ class Estimator(object):
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception
(`tf.errors.OutOfRangeError` or `StopIteration`).
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade
+ Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A `tf.data.Dataset` object: Outputs of `Dataset` object must have
@@ -595,8 +606,7 @@ class Estimator(object):
"""Exports inference graph as a `SavedModel` into the given dir.
For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with
- Estimators}.
+ [Using SavedModel with Estimators](https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
This method builds a new graph by first calling the
`serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling
@@ -1020,16 +1030,21 @@ class Estimator(object):
'QueueRunner. That means predict yields forever. '
'This is probably a mistake.')
- def _get_features_and_labels_from_input_fn(self, input_fn, mode,
- distribution=None):
- """Extracts the `features` and labels from return values of `input_fn`."""
+ def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
result = self._call_input_fn(input_fn, mode)
- return estimator_util.parse_input_fn_result(result)
+ iterator = result.make_initializable_iterator()
+ input_hooks = [estimator_util._DatasetInitializerHook(iterator)] # pylint: disable=protected-access
+ return iterator, input_hooks
+
+ def _get_features_and_labels_from_input_fn(self, input_fn, mode):
+ """Extracts the `features` and labels from return values of `input_fn`."""
+ return estimator_util.parse_input_fn_result(
+ self._call_input_fn(input_fn, mode))
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -1225,29 +1240,20 @@ class Estimator(object):
steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # we want to add to the global collection in the main thread not the
+ # tower threads.
+ ops.add_to_collection(
+ training_util.GLOBAL_STEP_READ_KEY,
+ self._train_distribution.read_var(global_step_tensor))
if is_tpu_strategy:
- # Create the iterator for run_on_dataset function
- # TODO(sourabhbajaj): refactor this out to call a function on the
- # strategy
- dataset = self._train_distribution.distribute_dataset(
- lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda
- model_fn_lib.ModeKeys.TRAIN))
- iterator = dataset.make_initializable_iterator()
- worker_hooks.append(
- estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access
-
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
-
# Create a step_fn from the train_op of grouped_estimator_spec
- def step_fn(ctx, inputs):
+ def step_fn(ctx, features, labels):
"""A single step that is passed to run_on_dataset."""
- features, labels = inputs
estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
@@ -1268,26 +1274,22 @@ class Estimator(object):
step_fn, iterator, iterations=steps_per_run_variable,
initial_loop_values={'loss': initial_training_loss})
distributed_train_op = ctx.run_op
- tpu_result = ctx.last_step_outputs
+ loss = ctx.last_step_outputs['loss']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.TRAIN,
- self._train_distribution))
- worker_hooks.extend(input_hooks)
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
+ features, labels = iterator.get_next()
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels, # although this will be None it seems
model_fn_lib.ModeKeys.TRAIN,
self.config)
+ loss = self._train_distribution.unwrap(
+ self._train_distribution.reduce(
+ distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0]
+ distributed_train_op = grouped_estimator_spec.train_op
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._train_distribution)
@@ -1301,21 +1303,10 @@ class Estimator(object):
grouped_estimator_spec.training_hooks)
training_chief_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_chief_hooks)
-
- # TODO(sourabhbajaj): Merge the two code paths and clean up the code
- if is_tpu_strategy:
- loss = tpu_result['loss']
- worker_hooks.append(
- estimator_util.StrategyInitFinalizeHook(
- self._train_distribution.initialize,
- self._train_distribution.finalize))
- else:
- loss = self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(),
- grouped_estimator_spec.loss,
- destinations='/device:CPU:0'))[0]
- distributed_train_op = grouped_estimator_spec.train_op
+ worker_hooks.append(
+ estimator_util.StrategyInitFinalizeHook(
+ self._train_distribution.initialize,
+ self._train_distribution.finalize))
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
@@ -1416,31 +1407,18 @@ class Estimator(object):
"""Builds the graph and related hooks to run evaluation."""
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(ops.get_default_graph())
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution))
if self._eval_distribution:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval_distributed(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval_distributed(input_fn, self.config))
else:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval(input_fn, self.config))
global_step_tensor = training_util.get_global_step(ops.get_default_graph())
# Call to warm_start has to be after model_fn is called.
self._maybe_warm_start(checkpoint_path)
- if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
- raise ValueError(
- 'Metric with name "%s" is not allowed, because Estimator ' %
- (model_fn_lib.LOSS_METRIC_KEY) +
- 'already defines a default metric with the same name.')
- eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
-
- update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops,
- self._eval_distribution)
-
if ops.GraphKeys.GLOBAL_STEP in eval_dict:
raise ValueError(
'Metric with name `global_step` is not allowed, because Estimator '
@@ -1465,26 +1443,70 @@ class Estimator(object):
return scaffold, update_op, eval_dict, all_hooks
- def _call_model_fn_eval(self, features, labels, config):
+ def _call_model_fn_eval(self, input_fn, config):
+ """Call model_fn for evaluation and handle return values."""
+ features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL)
+
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, config)
- loss_metric = metrics_lib.mean(estimator_spec.loss)
- return (loss_metric, estimator_spec.scaffold,
- estimator_spec.evaluation_hooks, estimator_spec.eval_metric_ops)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss)
+ update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
+ return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
+ input_hooks, update_op, eval_dict)
- def _call_model_fn_eval_distributed(self, features, labels, config):
+ def _call_model_fn_eval_distributed(self, input_fn, config):
"""Call model_fn in distribution mode and handle return values."""
- grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
- self._call_model_fn, features, labels,
- model_fn_lib.ModeKeys.EVAL, config)
+
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution)
+
+ is_tpu_strategy = (
+ self._eval_distribution.__class__.__name__ == 'TPUStrategy')
+
+ if is_tpu_strategy:
+ def step_fn(ctx, features, labels):
+ """Runs one step of the eval computation and captures outputs."""
+ estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels, model_fn_lib.ModeKeys.EVAL,
+ config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+ ctx.set_non_tensor_output(name='estimator_spec', output=estimator_spec)
+ ctx.set_non_tensor_output(name='eval_dict', output=eval_dict)
+ return update_op
+
+ # TODO(priyag): Fix eval step hook to account for steps_per_run.
+ ctx = self._eval_distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ update_op = ctx.run_op
+ eval_dict = ctx.non_tensor_outputs['eval_dict']
+ grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
+ else:
+ features, labels = iterator.get_next()
+ grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels,
+ model_fn_lib.ModeKeys.EVAL, config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ grouped_estimator_spec.eval_metric_ops, grouped_estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._eval_distribution)
evaluation_hooks = self._eval_distribution.unwrap(
grouped_estimator_spec.evaluation_hooks)[0]
- loss_metric = self._eval_distribution.call_for_each_tower(
- metrics_lib.mean, grouped_estimator_spec.loss)
- return (loss_metric, scaffold,
- evaluation_hooks, grouped_estimator_spec.eval_metric_ops)
+ evaluation_hooks = evaluation_hooks + (
+ estimator_util.StrategyInitFinalizeHook(
+ self._eval_distribution.initialize,
+ self._eval_distribution.finalize),)
+
+ return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
@@ -1520,6 +1542,23 @@ class Estimator(object):
warm_starting_util.warm_start(*self._warm_start_settings)
+def _verify_and_create_loss_metric(eval_metric_ops, loss, distribution=None):
+ """Creates a metric for loss and throws an error if one already exists."""
+ if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
+ raise ValueError(
+ 'Metric with name "%s" is not allowed, because Estimator ' %
+ (model_fn_lib.LOSS_METRIC_KEY) +
+ 'already defines a default metric with the same name.')
+
+ if distribution is None:
+ loss_metric = metrics_lib.mean(loss)
+ else:
+ loss_metric = distribution.call_for_each_tower(
+ metrics_lib.mean, loss)
+ eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
+ return eval_metric_ops
+
+
def maybe_overwrite_model_dir_and_session_config(config, model_dir):
"""Overwrite estimator config by `model_dir` and `session_config` if needed.
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 05d1a04d2f..d316742a83 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -955,6 +955,19 @@ class EstimatorTrainTest(test.TestCase):
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
+ def test_config_should_not_be_evaluator_or_ps(self):
+
+ class FakeEvaluatorConfig(run_config.RunConfig):
+
+ @property
+ def task_type(self):
+ return run_config.TaskType.EVALUATOR
+
+ est = estimator.Estimator(
+ model_fn=dummy_model_fn, config=FakeEvaluatorConfig())
+ with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
+ est.train(dummy_input_fn, steps=1)
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index c4b006955c..fcccfbde7a 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -323,6 +323,43 @@ class LatestExporterTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
+ def test_garbage_collect_exports_with_trailing_delimiter(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ gfile.MkDir(export_dir_base)
+ export_dir_1 = _create_test_export_dir(export_dir_base)
+ export_dir_2 = _create_test_export_dir(export_dir_base)
+ export_dir_3 = _create_test_export_dir(export_dir_base)
+ export_dir_4 = _create_test_export_dir(export_dir_base)
+
+ self.assertTrue(gfile.Exists(export_dir_1))
+ self.assertTrue(gfile.Exists(export_dir_2))
+ self.assertTrue(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
+ def _serving_input_receiver_fn():
+ return array_ops.constant([1]), None
+
+ exporter = exporter_lib.LatestExporter(
+ name="latest_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ exports_to_keep=1)
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ # Garbage collect all but the most recent 2 exports,
+ # where recency is determined based on the timestamp directory names.
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+ mock_list_directory.return_value = [
+ os.path.basename(export_dir_1) + b"/",
+ os.path.basename(export_dir_2) + b"/",
+ os.path.basename(export_dir_3) + b"/",
+ os.path.basename(export_dir_4) + b"/",
+ ]
+ exporter.export(estimator, export_dir_base, None, None, False)
+
+ self.assertFalse(gfile.Exists(export_dir_1))
+ self.assertFalse(gfile.Exists(export_dir_2))
+ self.assertFalse(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
def _create_test_export_dir(export_dir_base):
export_dir = _get_timestamped_export_dir(export_dir_base)
diff --git a/tensorflow/python/estimator/gc.py b/tensorflow/python/estimator/gc.py
index 9f8a463ec1..03ad33dd6b 100644
--- a/tensorflow/python/estimator/gc.py
+++ b/tensorflow/python/estimator/gc.py
@@ -201,9 +201,11 @@ def _get_paths(base_dir, parser):
raw_paths = gfile.ListDirectory(base_dir)
paths = []
for r in raw_paths:
- p = parser(Path(os.path.join(compat.as_str_any(base_dir),
- compat.as_str_any(r)),
- None))
+ # ListDirectory() return paths with "/" at the last if base_dir was GCS URL
+ r = compat.as_str_any(r)
+ if r[-1] == '/':
+ r = r[0:len(r)-1]
+ p = parser(Path(os.path.join(compat.as_str_any(base_dir), r), None))
if p:
paths.append(p)
return sorted(paths)
diff --git a/tensorflow/python/estimator/gc_test.py b/tensorflow/python/estimator/gc_test.py
index 2cbdd511d1..53c3d4ca2a 100644
--- a/tensorflow/python/estimator/gc_test.py
+++ b/tensorflow/python/estimator/gc_test.py
@@ -140,6 +140,17 @@ class GcTest(test_util.TensorFlowTestCase):
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc._get_paths(base_dir, _create_parser(base_dir))
+ def testGcsDirWithSeparator(self):
+ base_dir = "gs://bucket/foo"
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+ # gfile.ListDirectory returns directory names with separator '/'
+ mock_list_directory.return_value = ["0/", "1/"]
+ self.assertEqual(
+ gc._get_paths(base_dir, _create_parser(base_dir)),
+ [
+ gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1)
+ ])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index e4ce5339d0..6361c6acc1 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -33,9 +33,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
-from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.keras.engine.network import Network
-from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
@@ -47,8 +44,6 @@ from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -92,184 +87,78 @@ def _any_weight_initialized(keras_model):
return False
-def _create_ordered_io(keras_model, estimator_io, is_input=True):
- """Create a list of tensors from IO dictionary based on Keras IO order.
+def _convert_estimator_io_to_keras(keras_model, features, labels):
+ """Converts estimator features and labels to keras input and target tensors.
Args:
- keras_model: An instance of compiled keras model.
- estimator_io: The features or labels (dict or plain array) from model_fn.
- is_input: True if dictionary is for inputs.
+ keras_model: a compiled `tf.keras.Model` instance, used to determine the
+ order of the returned lists.
+ features: Dict of tensors or `None`.
+ labels: Dict of tensors, a single tensor, or `None`.
Returns:
- A list of tensors based on Keras IO order.
-
- Raises:
- ValueError: if dictionary keys cannot be found in Keras model input_names
- or output_names.
- """
- if isinstance(estimator_io, (list, tuple)):
- # Case currently not supported by most built-in input_fn,
- # but it's good to have for sanity
- return [_convert_tensor(x) for x in estimator_io]
- elif isinstance(estimator_io, dict):
- if is_input:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.input_names
- else:
- keras_io_names = [
- 'input_%d' % i for i in range(1, len(estimator_io) + 1)]
- else:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.output_names
- else:
- keras_io_names = [
- 'output_%d' % i for i in range(1, len(estimator_io) + 1)]
-
- for key in estimator_io:
- if key not in keras_io_names:
- raise ValueError(
- 'Cannot find %s with name "%s" in Keras Model. '
- 'It needs to match one '
- 'of the following: %s' % ('input' if is_input else 'output', key,
- ', '.join(keras_io_names)))
- tensors = [_convert_tensor(estimator_io[io_name])
- for io_name in keras_io_names]
- return tensors
- else:
- # Plain array.
- return _convert_tensor(estimator_io)
-
-
-def _in_place_subclassed_model_reset(model):
- """Substitute for model cloning that works for subclassed models.
-
- Subclassed models cannot be cloned because their topology is not serializable.
- To "instantiate" an identical model in a new TF graph, we reuse the original
- model object, but we clear its state.
-
- After calling this function on a model instance, you can use the model
- instance as if it were a model clone (in particular you can use it in a new
- graph).
-
- This method clears the state of the input model. It is thus destructive.
- However the original state can be restored fully by calling
- `_in_place_subclassed_model_state_restoration`.
-
- Args:
- model: Instance of a Keras model created via subclassing.
-
- Raises:
- ValueError: In case the model uses a subclassed model as inner layer.
+ Tuple of (
+ list of input tensors or `None`,
+ list of target tensors or `None`)
+ The order of tensors is determined by the order set in the keras model.
"""
- assert not model._is_graph_network # Only makes sense for subclassed networks
- # Retrieve all layers tracked by the model as well as their attribute names
- attributes_cache = {}
- for name in dir(model):
- try:
- value = getattr(model, name)
- except (AttributeError, ValueError, TypeError):
- continue
- if isinstance(value, Layer):
- attributes_cache[name] = value
- assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
- # Handle case: list/tuple of layers (also tracked by the Network API).
- if value and all(isinstance(val, Layer) for val in value):
- raise ValueError('We do not support the use of list-of-layers '
- 'attributes in subclassed models used with '
- '`model_to_estimator` at this time. Found list '
- 'model: %s' % name)
-
- # Replace layers on the model with fresh layers
- layers_to_names = {value: key for key, value in attributes_cache.items()}
- original_layers = model._layers[:]
- model._layers = data_structures.NoDependency([])
- for layer in original_layers: # We preserve layer order.
- config = layer.get_config()
- # This will not work for nested subclassed models used as layers.
- # This would be theoretically possible to support, but would add complexity.
- # Only do it if users complain.
- if isinstance(layer, Network) and not layer._is_graph_network:
- raise ValueError('We do not support the use of nested subclassed models '
- 'in `model_to_estimator` at this time. Found nested '
- 'model: %s' % layer)
- fresh_layer = layer.__class__.from_config(config)
- name = layers_to_names[layer]
- setattr(model, name, fresh_layer)
-
- # Cache original model build attributes (in addition to layers)
- if (not hasattr(model, '_original_attributes_cache') or
- model._original_attributes_cache is None):
- if model.built:
- attributes_to_cache = [
- 'inputs',
- 'outputs',
- '_feed_outputs',
- '_feed_output_names',
- '_feed_output_shapes',
- '_feed_loss_fns',
- 'loss_weights_list',
- 'targets',
- '_feed_targets',
- 'sample_weight_modes',
- 'weighted_metrics',
- 'metrics_names',
- 'metrics_tensors',
- 'metrics_updates',
- 'stateful_metric_names',
- 'total_loss',
- 'sample_weights',
- '_feed_sample_weights',
- 'train_function',
- 'test_function',
- 'predict_function',
- '_collected_trainable_weights',
- '_feed_inputs',
- '_feed_input_names',
- '_feed_input_shapes',
- 'optimizer',
- ]
- for name in attributes_to_cache:
- attributes_cache[name] = getattr(model, name)
- model._original_attributes_cache = data_structures.NoDependency(
- attributes_cache)
- # Reset built state
- model.built = False
- model.inputs = None
- model.outputs = None
-
-
-def _in_place_subclassed_model_state_restoration(model):
- """Restores the original state of a model after it was "reset".
-
- This undoes this action of `_in_place_subclassed_model_reset`.
- Args:
- model: Instance of a Keras model created via subclassing, on which
- `_in_place_subclassed_model_reset` was previously called.
- """
- assert not model._is_graph_network
- # Restore layers and build attributes
- if (hasattr(model, '_original_attributes_cache') and
- model._original_attributes_cache is not None):
- # Models have sticky attribute assignment, so we want to be careful to add
- # back the previous attributes and track Layers by their original names
- # without adding dependencies on "utility" attributes which Models exempt
- # when they're constructed.
- model._layers = data_structures.NoDependency([])
- for name, value in model._original_attributes_cache.items():
- if not isinstance(value, checkpointable.CheckpointableBase):
- # If this value is not already checkpointable, it's probably that way
- # for a reason; we don't want to start tracking data structures that the
- # original Model didn't.
- value = data_structures.NoDependency(value)
- setattr(model, name, value)
- model._original_attributes_cache = None
- else:
- # Restore to the state of a never-called model.
- model.built = False
- model.inputs = None
- model.outputs = None
+ def _to_ordered_tensor_list(obj, key_order, obj_name, order_name):
+ """Convert obj to an ordered list of tensors.
+
+ Args:
+ obj: List, dict, or single tensor. May be `None`.
+ key_order: List of strings with the order to return (used if obj is a
+ dict).
+ obj_name: String name of object (e.g. "features" or "labels")
+ order_name: String name of the key order (e.g. "inputs" or "outputs")
+
+ Returns:
+ List of tensors, or `None`
+
+ Raises:
+ KeyError: If obj has invalid keys.
+ """
+ if obj is None:
+ return None
+ elif isinstance(obj, (list, tuple)):
+ return [_convert_tensor(x) for x in obj]
+ elif isinstance(obj, dict):
+ # Ensure that the obj keys and keys in key_order are exactly the same.
+ different_keys = set(obj.keys()) ^ set(key_order)
+
+ if different_keys:
+ raise KeyError(
+ 'The dictionary passed into {obj_name} does not have the expected '
+ '{order_name} keys defined in the keras model.'
+ '\n\tExpected keys: {order_keys}'
+ '\n\t{obj_name} keys: {obj_keys}'
+ '\n\tDifference: {different_keys}'.format(
+ order_name=order_name, order_keys=set(key_order),
+ obj_name=obj_name, obj_keys=set(obj.keys()),
+ different_keys=different_keys))
+
+ return [_convert_tensor(obj[key]) for key in key_order]
+ else: # Assume obj is a tensor.
+ return [_convert_tensor(obj)]
+
+ input_names = None
+ output_names = None
+ if isinstance(features, dict):
+ input_names = (
+ keras_model.input_names if keras_model._is_graph_network else
+ ['input_%d' % i for i in range(1, len(features) + 1)])
+ if isinstance(labels, dict):
+ output_names = (
+ keras_model.output_names if keras_model._is_graph_network else
+ ['output_%d' % i for i in range(1, len(labels) + 1)])
+
+ input_tensors = _to_ordered_tensor_list(
+ features, input_names, 'features', 'inputs')
+ target_tensors = _to_ordered_tensor_list(
+ labels, output_names, 'labels', 'outputs')
+
+ return input_tensors, target_tensors
def _clone_and_build_model(mode,
@@ -289,61 +178,14 @@ def _clone_and_build_model(mode,
Returns:
The newly built model.
"""
- # Set to True during training, False for inference.
+ # Set to True during training, False for inference or testing.
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
-
- # Get list of inputs.
- if features is None:
- input_tensors = None
- else:
- input_tensors = _create_ordered_io(keras_model,
- estimator_io=features,
- is_input=True)
- # Get list of outputs.
- if labels is None:
- target_tensors = None
- elif isinstance(labels, dict):
- target_tensors = _create_ordered_io(keras_model,
- estimator_io=labels,
- is_input=False)
- else:
- target_tensors = [
- _convert_tensor(labels)
- ]
-
- if keras_model._is_graph_network:
- if custom_objects:
- with CustomObjectScope(custom_objects):
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = keras_model
- _in_place_subclassed_model_reset(model)
- if input_tensors is not None:
- model._set_inputs(input_tensors)
-
- # Compile/Build model
- if mode is model_fn_lib.ModeKeys.PREDICT:
- if isinstance(model, models.Sequential):
- model.build()
- else:
- if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
- optimizer = keras_model.optimizer
- else:
- optimizer_config = keras_model.optimizer.get_config()
- optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
- optimizer.iterations = training_util.get_or_create_global_step()
-
- model.compile(
- optimizer,
- keras_model.loss,
- metrics=keras_model.metrics,
- loss_weights=keras_model.loss_weights,
- sample_weight_mode=keras_model.sample_weight_mode,
- weighted_metrics=keras_model.weighted_metrics,
- target_tensors=target_tensors)
- return model
+ input_tensors, target_tensors = _convert_estimator_io_to_keras(
+ keras_model, features, labels)
+ return models.clone_and_build_model(
+ keras_model, input_tensors, target_tensors, custom_objects,
+ compile_clone=(mode != model_fn_lib.ModeKeys.PREDICT),
+ in_place_reset=(not keras_model._is_graph_network))
def _create_keras_model_fn(keras_model, custom_objects=None):
@@ -423,7 +265,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
if not model._is_graph_network:
# Reset model state to original state,
# to avoid `model_fn` being destructive for the initial model argument.
- _in_place_subclassed_model_state_restoration(keras_model)
+ models.in_place_subclassed_model_state_restoration(keras_model)
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=predictions,
@@ -487,8 +329,9 @@ def model_to_estimator(keras_model=None,
config=None):
"""Constructs an `Estimator` instance from given keras model.
- For usage example, please see
- @{$guide/estimators$creating_estimators_from_keras_models}.
+ For usage example, please see:
+ [Creating estimators from Keras
+ Models](https://tensorflow.org/guide/estimators#model_to_estimator).
Args:
keras_model: A compiled Keras model object. This argument is mutually
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 332e385726..290c4604ce 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -184,12 +184,14 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gfile.MakeDirs(self._base_dir)
self._config = run_config_lib.RunConfig(
tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+ super(TestKerasEstimator, self).setUp()
def tearDown(self):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
if os.path.isdir(self._base_dir):
gfile.DeleteRecursively(self._base_dir)
+ super(TestKerasEstimator, self).tearDown()
def test_train(self):
for model_type in ['sequential', 'functional']:
@@ -511,19 +513,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
input_dict = {'input_1': x_train}
output_dict = {'invalid_output_name': y_train}
return input_dict, output_dict
-
model = simple_functional_model()
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
with self.test_session():
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
-
with self.test_session():
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_input_name'):
est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_output_name'):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 9db9ccd01d..007970bef7 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -141,7 +141,7 @@ class EstimatorSpec(
prediction.
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
- train_op: Op for the training step.
+ train_op: Op to run one training step.
eval_metric_ops: Dict of metric results keyed by name. The values of the
dict are the results of calling a metric function, namely a
`(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index bb1305767f..e6bd263c80 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -129,8 +129,8 @@ class TrainSpec(
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -193,8 +193,8 @@ class EvalSpec(
Args:
input_fn: A function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](https://tensorflow.org/api_guides/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -837,6 +837,13 @@ class _TrainingExecutor(object):
if difference > 0:
logging.info('Waiting %f secs before starting next eval run.', difference)
time.sleep(difference)
+ elif (throttle_secs == 0 and
+ eval_result.status != _EvalStatus.EVALUATED):
+ # Prints a user-actionable warning to avoid unnecessary load on evaluator.
+ logging.warning(
+ 'EvalSpec.throttle_secs is set as 0. This might overload the job '
+ 'before finding (next) new checkpoint. Please consider to increase '
+ 'it.')
return (eval_result, should_early_stop)
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index dc106c7d3b..7d46917a6f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -83,6 +83,9 @@ _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
+_INPROPER_THROTTL_SECS = (
+ 'EvalSpec.throttle_secs is set as 0.*Please consider to increase')
+
# The message should NOT have 'local' word as part of it. As (?!word) is looking
# ahead, so, the $ (ending) check is required; otherwise, it will match
# partially and return successuful.
@@ -1281,7 +1284,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
]
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(logging, 'warning') as mock_log:
@@ -1295,6 +1298,34 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
# successuful evaluation)
self.assertEqual(2, mock_log.call_count)
+ def test_warning_if_throttle_secs_is_zero(self):
+ training_max_step = 200
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: training_max_step}
+ ]
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec.max_steps = training_max_step
+
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ # We need to make the first one invalid, so it will check the
+ # throttle_secs=0.
+ mock_est.latest_checkpoint.side_effect = [None, 'path']
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ executor.run_evaluator()
+
+ # First ckpt is invalid.
+ self.assertEqual(2, mock_est.latest_checkpoint.call_count)
+ self.assertEqual(1, mock_est.evaluate.call_count)
+
+ self.assertRegexpMatches(str(mock_log.call_args), _INPROPER_THROTTL_SECS)
+
def test_continuous_eval_listener_eval_result(self):
training_max_step = 200
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index b3eb57d067..eca34ac26e 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Operations that generate constants.
-See the @{$python/constant_op$constants guide}.
+See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
"""
# Must be separate from array_ops to avoid a cyclic dependency.
@@ -145,6 +145,17 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
[-1. -1. -1.]]
```
+ `tf.constant` differs from `tf.fill` in a few ways:
+
+ * `tf.constant` supports arbitrary constants, not just uniform scalar
+ Tensors like `tf.fill`.
+ * `tf.constant` creates a `Const` node in the computation graph with the
+ exact value at graph construction time. On the other hand, `tf.fill`
+ creates an Op in the graph that is expanded at runtime.
+ * Because `tf.constant` only embeds constant values in the graph, it does
+ not support dynamic shapes based on other runtime Tensors, whereas
+ `tf.fill` does.
+
Args:
value: A constant value (or list) of output type `dtype`.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 5527f52860..21eb306865 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -5784,6 +5784,38 @@ class GraphKeys(object):
return cls.GLOBAL_VARIABLES
+def dismantle_graph(graph):
+ """Cleans up reference cycles from a `Graph`.
+
+ Helpful for making sure the garbage collector doesn't need to run after a
+ temporary `Graph` is no longer needed.
+
+ Args:
+ graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
+ after this function runs.
+ """
+ # pylint: disable=protected-access
+ # OrderedDict, constructed on Graph creation, makes a simple reference loop
+ # and hides it in an __attribute in some Python versions. We don't need to
+ # throw an error if we can't find it, but if we do find it we can break the
+ # loop to avoid creating work for the garbage collector.
+ graph_operations = graph.get_operations()
+ problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
+ # pylint: enable=protected-access
+ if problematic_cycle:
+ try:
+ del problematic_cycle[0][:]
+ except TypeError:
+ # This is probably not one of the problematic Python versions. Continue
+ # with the rest of our cleanup.
+ pass
+ # Now clean up Operation<->Graph reference cycles by clearing all of the
+ # attributes for the Graph and its ops.
+ for op in graph_operations:
+ op.__dict__ = {}
+ graph.__dict__ = {}
+
+
@tf_export("add_to_collection")
def add_to_collection(name, value):
"""Wrapper for `Graph.add_to_collection()` using the default graph.
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 76d4c2017c..2022fbcbaa 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -102,15 +102,6 @@ string TensorPBString(const TensorProto& pb) {
return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index 031b4a384e..f2270342b0 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -483,15 +483,6 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
return nullptr;
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name)
: op_def_(op_def),
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index bd0f691a61..11b681d544 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -498,7 +498,8 @@ class TensorShape(object):
If a tensor is produced by an operation of type `"Foo"`, its shape
may be inferred if there is a registered shape function for
- `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
+ `"Foo"`. See [Shape
+ functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
for details of shape functions and how to register them. Alternatively,
the shape may be set explicitly using `tf.Tensor.set_shape`.
"""
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index c2c97dd684..d690f08d88 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -369,6 +369,7 @@ def enable_c_shapes(fn):
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+
# pylint: enable=protected-access
return wrapper
@@ -418,7 +419,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
previous_count = len(gc.get_objects())
collection_sizes_before = {
collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections}
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
@@ -430,8 +432,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
if len(collection) > size_before:
raise AssertionError(
("Collection %s increased in size from "
- "%d to %d (current items %s).")
- % (collection_key, size_before, len(collection), collection))
+ "%d to %d (current items %s).") % (collection_key, size_before,
+ len(collection), collection))
# Make sure our collection checks don't show up as leaked memory by
# removing references to temporary variables.
del collection
@@ -446,8 +448,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
# Using plain assert because not all classes using this decorator
# have assertLessEqual
assert new_count <= previous_count, (
- "new_count(%d) is not less than or equal to previous_count(%d)" % (
- new_count, previous_count))
+ "new_count(%d) is not less than or equal to previous_count(%d)" %
+ (new_count, previous_count))
gc.enable()
return decorator
@@ -547,10 +549,12 @@ def assert_no_garbage_created(f):
return "<%s %d>" % (obj.__class__.__name__, id(obj))
logging.error(" Object type: %s", _safe_object_str(obj))
- logging.error(" Referrer types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
- logging.error(" Referent types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
+ logging.error(
+ " Referrer types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
+ logging.error(
+ " Referent types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
logging.error(" Object attribute names: %s", dir(obj))
logging.error(" Object __str__:")
logging.error(obj)
@@ -629,9 +633,8 @@ def generate_combinations_with_testcase_name(**kwargs):
for combination in combinations:
assert isinstance(combination, OrderedDict)
name = "".join([
- "_{}_{}".format(
- "".join(filter(str.isalnum, key)),
- "".join(filter(str.isalnum, str(value))))
+ "_{}_{}".format("".join(filter(str.isalnum, key)), "".join(
+ filter(str.isalnum, str(value))))
for key, value in combination.items()
])
named_combinations.append(
@@ -718,7 +721,7 @@ def run_in_graph_and_eager_modes(func=None,
def decorated(self, **kwargs):
try:
- with ops.Graph().as_default():
+ with context.graph_mode():
with self.test_session(use_gpu=use_gpu, config=config):
f(self, **kwargs)
except unittest.case.SkipTest:
@@ -736,15 +739,19 @@ def run_in_graph_and_eager_modes(func=None,
run_eagerly = assert_no_new_tensors(
assert_no_garbage_created(run_eagerly))
- with context.eager_mode():
+ if reset_test:
+ # This decorator runs the wrapped test twice.
+ # Reset the test environment between runs.
+ self.tearDown()
+ self._tempdir = None
+ # Create a new graph for the eagerly executed version of this test for
+ # better isolation.
+ graph_for_eager_test = ops.Graph()
+ with graph_for_eager_test.as_default(), context.eager_mode():
if reset_test:
- # This decorator runs the wrapped test twice.
- # Reset the test environment between runs.
- self.tearDown()
- self._tempdir = None
self.setUp()
-
run_eagerly(self, **kwargs)
+ ops.dismantle_graph(graph_for_eager_test)
return decorated
@@ -967,21 +974,64 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
- def test_session(self,
- graph=None,
- config=None,
- use_gpu=False,
- force_gpu=False):
+ def session(self, graph=None, config=None, use_gpu=False, force_gpu=False):
"""Returns a TensorFlow Session for use in executing tests.
- This method should be used for all functional tests.
+ Note that this will set this session and the graph as global defaults.
- This method behaves different than session.Session: for performance reasons
- `test_session` will by default (if `graph` is None) reuse the same session
- across tests. This means you may want to either call the function
- `reset_default_graph()` before tests, or if creating an explicit new graph,
- pass it here (simply setting it with `as_default()` won't do it), which will
- trigger the creation of a new session.
+ Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+ `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
+ `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
+ possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
+ the CPU.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ with self.session(use_gpu=True):
+ valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ result = MyOperator(valid_input).eval()
+ self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
+ invalid_input = [-1.0, 2.0, 7.0]
+ with self.assertRaisesOpError("negative input not supported"):
+ MyOperator(invalid_input).eval()
+ ```
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ use_gpu: If True, attempt to run as many ops as possible on GPU.
+ force_gpu: If True, pin all ops to `/device:GPU:0`.
+
+ Yields:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ if context.executing_eagerly():
+ yield None
+ else:
+ sess = self._create_session(graph, config, use_gpu, force_gpu)
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ # We need to do this to make sure the session closes, otherwise, even
+ # if the user does with self.session():, it will not close the session.
+ with constrained_sess:
+ yield constrained_sess
+
+ @contextlib.contextmanager
+ def cached_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Returns a TensorFlow Session for use in executing tests.
+
+ This method behaves differently than self.session(): for performance reasons
+ `cached_session` will by default reuse the same session within the same
+ test. The session returned by this function will only be closed at the end
+ of the test (in the TearDown function).
Use the `use_gpu` and `force_gpu` options to control where ops are run. If
`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
@@ -993,7 +1043,7 @@ class TensorFlowTestCase(googletest.TestCase):
```python
class MyOperatorTest(test_util.TensorFlowTestCase):
def testMyOperator(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True) as sess:
valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
result = MyOperator(valid_input).eval()
self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
@@ -1009,74 +1059,39 @@ class TensorFlowTestCase(googletest.TestCase):
use_gpu: If True, attempt to run as many ops as possible on GPU.
force_gpu: If True, pin all ops to `/device:GPU:0`.
- Returns:
+ Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
+ if context.executing_eagerly():
+ yield None
+ else:
+ with self._get_cached_session(
+ graph, config, use_gpu, force_gpu,
+ crash_if_inconsistent_args=True) as sess:
+ yield sess
+
+ @contextlib.contextmanager
+ def test_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Use cached_session instead."""
if self.id().endswith(".test_session"):
self.skipTest("Not a test.")
- def prepare_config(config):
- """Returns a config for sessions.
-
- Args:
- config: An optional config_pb2.ConfigProto to use to configure the
- session.
- Returns:
- A config_pb2.ConfigProto object.
- """
- if config is None:
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
- config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
- config.allow_soft_placement = False
- # Don't perform optimizations for tests so we don't inadvertently run
- # gpu ops on cpu
- config.graph_options.optimizer_options.opt_level = -1
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
if context.executing_eagerly():
yield None
- elif graph is None:
- if self._cached_session is None:
- self._cached_session = session.Session(
- graph=None, config=prepare_config(config))
- sess = self._cached_session
- with sess.graph.as_default(), sess.as_default():
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
- yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
else:
- with session.Session(graph=graph, config=prepare_config(config)) as sess:
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
+ if graph is None:
+ with self._get_cached_session(
+ graph, config, use_gpu, force_gpu,
+ crash_if_inconsistent_args=False) as sess:
+ yield sess
+ else:
+ with self.session(graph, config, use_gpu, force_gpu) as sess:
yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
# pylint: enable=g-doc-return-or-yield
@@ -1202,9 +1217,10 @@ class TensorFlowTestCase(googletest.TestCase):
msg: An optional string message to append to the failure message.
"""
# f1 == f2 is needed here as we might have: f1, f2 = inf, inf
- self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err,
- "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
- if msg is not None else ""))
+ self.assertTrue(
+ f1 == f2 or math.fabs(f1 - f2) <= err,
+ "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
+ if msg is not None else ""))
def assertArrayNear(self, farray1, farray2, err, msg=None):
"""Asserts that two float arrays are near each other.
@@ -1250,8 +1266,9 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
- (a.shape, b.shape))
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#
@@ -1453,8 +1470,9 @@ class TensorFlowTestCase(googletest.TestCase):
msg = msg if msg else ""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
- " %s" % (a.shape, b.shape, msg))
+ self.assertEqual(
+ a.shape, b.shape, "Shape mismatch: expected %s, got %s."
+ " %s" % (a.shape, b.shape, msg))
same = (a == b)
if (a.dtype in [
@@ -1682,8 +1700,8 @@ class TensorFlowTestCase(googletest.TestCase):
self.fail(exception_type.__name__ + " not raised")
except Exception as e: # pylint: disable=broad-except
if not isinstance(e, exception_type) or not predicate(e):
- raise AssertionError("Exception of type %s: %s" % (str(type(e)),
- str(e)))
+ raise AssertionError(
+ "Exception of type %s: %s" % (str(type(e)), str(e)))
# pylint: enable=g-doc-return-or-yield
@@ -1719,8 +1737,9 @@ class TensorFlowTestCase(googletest.TestCase):
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
- self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
- (device1, device2, msg))
+ self.assertEqual(
+ device1, device2,
+ "Devices %s and %s are not equal. %s" % (device1, device2, msg))
# Fix Python 3 compatibility issues
if six.PY3:
@@ -1734,6 +1753,113 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: enable=invalid-name
+ @contextlib.contextmanager
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """Set the session and its graph to global default and constrain devices."""
+ if context.executing_eagerly():
+ yield None
+ else:
+ with sess.graph.as_default(), sess.as_default():
+ if force_gpu:
+ # Use the name of an actual device if one is detected, or
+ # '/device:GPU:0' otherwise
+ gpu_name = gpu_device_name()
+ if not gpu_name:
+ gpu_name = "/device:GPU:0"
+ with sess.graph.device(gpu_name):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device("/cpu:0"):
+ yield sess
+
+ def _create_session(self, graph, config, use_gpu, force_gpu):
+ """See session() for details."""
+ if context.executing_eagerly():
+ return None
+ else:
+
+ def prepare_config(config):
+ """Returns a config for sessions.
+
+ Args:
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ Returns:
+ A config_pb2.ConfigProto object.
+ """
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ config = config_pb2.ConfigProto().CopyFrom(config)
+ config.allow_soft_placement = False
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+ return session.Session(graph=graph, config=prepare_config(config))
+
+ @contextlib.contextmanager
+ def _get_cached_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False,
+ crash_if_inconsistent_args=True):
+ """See cached_session() for documentation."""
+ if context.executing_eagerly():
+ yield None
+ else:
+ if self._cached_session is None:
+ sess = self._create_session(
+ graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu)
+ self._cached_session = sess
+ self._cached_graph = graph
+ self._cached_config = config
+ self._cached_use_gpu = use_gpu
+ self._cached_force_gpu = force_gpu
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ yield constrained_sess
+ else:
+ if crash_if_inconsistent_args and self._cached_graph is not graph:
+ raise ValueError("The graph used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_config is not config:
+ raise ValueError("The config used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu:
+ raise ValueError(
+ "The use_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and (self._cached_force_gpu is
+ not force_gpu):
+ raise ValueError(
+ "The force_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ # If you modify this logic, make sure to modify it in _create_session
+ # as well.
+ sess = self._cached_session
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ yield constrained_sess
+
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index f983cbef04..f68c0ddecb 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -22,6 +22,7 @@ import collections
import copy
import random
import threading
+import weakref
import numpy as np
@@ -40,6 +41,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -57,6 +59,33 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
{"hello": "Variable"}, ops.get_default_graph())
+ def test_session_functions(self):
+ with self.test_session() as sess:
+ sess_ref = weakref.ref(sess)
+ with self.cached_session(graph=None, config=None) as sess2:
+ # We make sure that sess2 is sess.
+ assert sess2 is sess
+ # We make sure we raise an exception if we use cached_session with
+ # different values.
+ with self.assertRaises(ValueError):
+ with self.cached_session(graph=ops.Graph()) as sess2:
+ pass
+ with self.assertRaises(ValueError):
+ with self.cached_session(use_gpu=True) as sess2:
+ pass
+ with self.assertRaises(ValueError):
+ with self.cached_session(force_gpu=True) as sess2:
+ pass
+ # We make sure that test_session will cache the session even after the
+ # with scope.
+ assert not sess_ref()._closed
+ with self.session() as unique_sess:
+ unique_sess_ref = weakref.ref(unique_sess)
+ with self.session() as sess2:
+ assert sess2 is not unique_sess
+ # We make sure the session is closed when we leave the with statement.
+ assert unique_sess_ref()._closed
+
def test_assert_equal_graph_def(self):
with ops.Graph().as_default() as g:
def_empty = g.as_graph_def()
@@ -666,6 +695,22 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(modes[2:], ["setup_eager", "run_eager"])
+# Its own test case to reproduce variable sharing issues which only pop up when
+# setUp() is overridden and super() is not called.
+class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ pass # Intentionally does not call TensorFlowTestCase's super()
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_no_variable_sharing(self):
+ variable_scope.get_variable(
+ name="step_size",
+ initializer=np.array(1e-5, np.float32),
+ use_resource=True,
+ trainable=False)
+
+
class GarbageCollectionTest(test_util.TensorFlowTestCase):
def test_no_reference_cycle_decorator(self):
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index fa1ec51aa7..e145b894f5 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -688,7 +688,7 @@ py_test(
py_test(
name = "training_test",
- size = "large",
+ size = "enormous",
srcs = ["engine/training_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 418586b85f..26068b2556 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2766,7 +2766,8 @@ class Function(object):
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: A name to help users identify what this function does.
- session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
+ session_kwargs: Arguments to `tf.Session.run()`:
+ `fetches`, `feed_dict`, `options`, `run_metadata`.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
@@ -2800,6 +2801,8 @@ class Function(object):
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
+ self.run_options = session_kwargs.pop('options', None)
+ self.run_metadata = session_kwargs.pop('run_metadata', None)
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates
# This requires us to wrap fetches in `identity` ops.
@@ -2857,6 +2860,9 @@ class Function(object):
callable_opts.fetch.append(x.name)
# Handle updates.
callable_opts.target.append(self.updates_op.name)
+ # Handle run_options.
+ if self.run_options:
+ callable_opts.run_options.CopyFrom(self.run_options)
# Create callable.
callable_fn = session._make_callable_from_options(callable_opts)
# Cache parameters corresponding to the generated callable, so that
@@ -2915,7 +2921,8 @@ class Function(object):
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
- fetched = self._callable_fn(*array_vals)
+ fetched = self._callable_fn(*array_vals,
+ run_metadata=self.run_metadata)
self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 40e7910061..a63267a5dd 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
import scipy.sparse
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -277,6 +278,29 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
+ def test_function_tf_run_options_with_run_metadata(self):
+ with self.test_session():
+ x_placeholder = keras.backend.placeholder(shape=())
+ y_placeholder = keras.backend.placeholder(shape=())
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ # enable run_options.
+ f = keras.backend.function(inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder],
+ options=run_options,
+ run_metadata=run_metadata)
+ output = f([10., 20.])
+ self.assertEqual(output, [30.])
+ self.assertGreater(len(run_metadata.partition_graphs), 0)
+ # disable run_options.
+ f1 = keras.backend.function(inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder],
+ run_metadata=run_metadata)
+ output1 = f1([10., 20.])
+ self.assertEqual(output1, [30.])
+ self.assertEqual(len(run_metadata.partition_graphs), 0)
+
def test_function_fetch_callbacks(self):
class CallbackStub(object):
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index e84e023384..7675a6586f 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -235,11 +235,8 @@ class KerasCallbacksTest(test.TestCase):
num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -298,9 +295,8 @@ class KerasCallbacksTest(test.TestCase):
test_samples=50,
input_shape=(1,),
num_classes=NUM_CLASSES)
- model = keras.models.Sequential((keras.layers.Dense(
- 1, input_dim=1, activation='relu'), keras.layers.Dense(
- 1, activation='sigmoid'),))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=1, num_classes=1, input_dim=1)
model.compile(
optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
@@ -334,11 +330,8 @@ class KerasCallbacksTest(test.TestCase):
num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
@@ -388,12 +381,8 @@ class KerasCallbacksTest(test.TestCase):
def make_model():
random_seed.set_random_seed(1234)
np.random.seed(1337)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1),
@@ -498,12 +487,8 @@ class KerasCallbacksTest(test.TestCase):
def make_model():
np.random.seed(1337)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1),
@@ -985,9 +970,8 @@ class KerasCallbacksTest(test.TestCase):
yield x, y
with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
- model.add(keras.layers.Dense(10, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
@@ -1083,11 +1067,8 @@ class KerasCallbacksTest(test.TestCase):
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
@@ -1179,40 +1160,36 @@ class KerasCallbacksTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_Tensorboard_eager(self):
- with self.test_session():
- temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
-
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=TRAIN_SAMPLES,
- test_samples=TEST_SAMPLES,
- input_shape=(INPUT_DIM,),
- num_classes=NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test)
- y_train = keras.utils.to_categorical(y_train)
-
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
- model.compile(
- loss='binary_crossentropy',
- optimizer=adam.AdamOptimizer(0.01),
- metrics=['accuracy'])
-
- cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)]
+ temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=2,
- verbose=0)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
- self.assertTrue(os.path.exists(temp_dir))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=adam.AdamOptimizer(0.01),
+ metrics=['accuracy'])
+
+ cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)]
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=2,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(temp_dir))
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 708fa1c807..cd74e36e68 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -394,10 +394,10 @@ class Network(base_layer.Layer):
no_dependency = isinstance(value, data_structures.NoDependency)
value = data_structures.sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
- if isinstance(value, (
- base_layer.Layer,
- Network,
- data_structures.CheckpointableDataStructure)):
+ if (isinstance(value, (base_layer.Layer,
+ Network,
+ data_structures.CheckpointableDataStructure))
+ or checkpointable_layer_utils.has_weights(value)):
try:
is_graph_network = self._is_graph_network
except AttributeError:
@@ -689,14 +689,14 @@ class Network(base_layer.Layer):
def trainable_weights(self):
return checkpointable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
return checkpointable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 415b15fde1..cf6fb44275 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -239,9 +239,9 @@ class Sequential(Model):
x = inputs
for layer in self.layers:
kwargs = {}
- if 'mask' in tf_inspect.getargspec(layer.call).args:
+ if 'mask' in tf_inspect.getfullargspec(layer.call).args:
kwargs['mask'] = mask
- if 'training' in tf_inspect.getargspec(layer.call).args:
+ if 'training' in tf_inspect.getfullargspec(layer.call).args:
kwargs['training'] = training
if isinstance(layer, Network) and layer._compute_output_and_mask_jointly:
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 3f8e120df0..28af8d61bc 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -25,22 +25,12 @@ from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop
-def _get_small_mlp(num_hidden, num_classes, input_dim=None):
- model = keras.models.Sequential()
- if input_dim:
- model.add(keras.layers.Dense(num_hidden, activation='relu',
- input_dim=input_dim))
- else:
- model.add(keras.layers.Dense(num_hidden, activation='relu'))
- model.add(keras.layers.Dense(num_classes, activation='softmax'))
- return model
-
-
class TestSequential(test.TestCase, parameterized.TestCase):
"""Most Sequential model API tests are covered in `training_test.py`.
"""
@@ -63,7 +53,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes, input_dim)
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden, num_classes, input_dim)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
x = np.random.random((batch_size, input_dim))
y = np.random.random((batch_size, num_classes))
@@ -94,7 +85,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -118,7 +109,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
num_samples = 50
steps_per_epoch = 10
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -145,9 +136,9 @@ class TestSequential(test.TestCase, parameterized.TestCase):
def get_model():
if deferred:
- model = _get_small_mlp(10, 4)
+ model = testing_utils.get_small_sequential_mlp(10, 4)
else:
- model = _get_small_mlp(10, 4, input_dim=3)
+ model = testing_utils.get_small_sequential_mlp(10, 4, input_dim=3)
model.compile(
optimizer=rmsprop.RMSPropOptimizer(1e-3),
loss='categorical_crossentropy',
@@ -262,7 +253,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -284,21 +275,21 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_shape_inference_deferred(self):
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
output_shape = model.compute_output_shape((None, 7))
self.assertEqual(tuple(output_shape.as_list()), (None, 5))
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_build_deferred(self):
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
model.build((None, 10))
self.assertTrue(model.built)
self.assertEqual(len(model.weights), 4)
# Test with nested model
- model = _get_small_mlp(4, 3)
- inner_model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 3)
+ inner_model = testing_utils.get_small_sequential_mlp(4, 5)
model.add(inner_model)
model.build((None, 10))
@@ -308,8 +299,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_nesting(self):
- model = _get_small_mlp(4, 3)
- inner_model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 3)
+ inner_model = testing_utils.get_small_sequential_mlp(4, 5)
model.add(inner_model)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
@@ -353,7 +344,7 @@ class TestSequentialEagerIntegration(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_build_before_fit(self):
# Fix for b/112433577
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
model.build((None, 6))
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index f71388cadb..502635c408 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -800,18 +800,18 @@ class Model(Network):
RuntimeError: If the model was never compiled.
"""
if sample_weight is not None and sample_weight.all():
- raise NotImplementedError('sample_weight is currently not supported when '
- 'using DistributionStrategy.')
+ raise NotImplementedError('`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.')
if class_weight:
- raise NotImplementedError('class_weight is currently not supported when '
- 'using DistributionStrategy.')
+ raise NotImplementedError('`class_weight` is currently not supported '
+ 'when using DistributionStrategy.')
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
# multiple devices.
if not isinstance(x, dataset_ops.Dataset):
- raise ValueError('When using DistributionStrategy you must specify a '
- 'Dataset object instead of a %s.' % type(x))
+ raise ValueError('When using DistributionStrategy, model inputs should be'
+ ' Dataset instances; found instead %s.' % type(x))
# TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
# function which returns a Dataset. Currently distribute_dataset() only
# accepts a function that returns a Dataset. Once we add support for being
@@ -834,8 +834,9 @@ class Model(Network):
next_element = iterator.get_next()
if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
+ raise ValueError('Please provide model inputs as a list or tuple of 2 '
+ 'elements: input and target pair. '
+ 'Received %s' % next_element)
x, y = next_element
# Validate that all the elements in x and y are of the same type and shape.
# We can then pass the first element of x and y to `_standardize_weights`
@@ -971,8 +972,9 @@ class Model(Network):
'required number of samples.')
if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
+ raise ValueError('Please provide model inputs as a list or tuple of 2 '
+ 'elements: input and target pair. '
+ 'Received %s' % next_element)
x, y = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
@@ -980,6 +982,10 @@ class Model(Network):
def _standardize_weights(self, x, y, sample_weight=None, class_weight=None,
batch_size=None,):
+ if sample_weight is not None and class_weight is not None:
+ logging.warning(
+ 'Received both a `sample_weight` and `class_weight` argument. '
+ 'The `class_weight` argument will be ignored.')
# First, we build/compile the model on the fly if necessary.
all_inputs = []
is_build_called = False
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 15e7d725de..8d835ed5a9 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -49,289 +49,287 @@ class TrainingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_fit_on_arrays(self):
- with self.test_session():
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- model.compile(
- optimizer,
- loss,
- metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
- loss_weights=loss_weights)
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
-
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
-
- # Test fit at different verbosity
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=2)
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- # Test model with input data as a list of lists
- model.fit(
- [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
- [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=2)
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
- # Test with validation data
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=2)
- # Test with validation split
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=0,
- validation_split=0.2)
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
- # Test with dictionary inputs
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- validation_data=({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- }),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.train_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- })
-
- # Test with lists for loss, metrics
- loss = ['mae', 'mse']
- model.compile(
- optimizer,
- loss,
- metrics=[metrics_module.CategoricalAccuracy(), 'mae'])
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
+ model = keras.models.Model([a, b], [d, e])
- # Test with dictionaries for loss, metrics, loss weights
- loss = {'dense': 'mse', 'dropout': 'mae'}
- loss_weights = {'dense': 1., 'dropout': 0.5}
- metrics = {
- 'dense': 'mse',
- 'dropout': metrics_module.CategoricalAccuracy()
- }
- model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test fit at different verbosity
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ # Test model with input data as a list of lists
+ model.fit(
+ [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
+ [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+
+ # Test with validation data
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+ # Test with validation split
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=0,
+ validation_split=0.2)
+
+ # Test with dictionary inputs
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ validation_data=({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ }),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.train_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ })
+
+ # Test with lists for loss, metrics
+ loss = ['mae', 'mse']
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'])
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Test with dictionaries for loss, metrics, loss weights
+ loss = {'dense': 'mse', 'dropout': 'mae'}
+ loss_weights = {'dense': 1., 'dropout': 0.5}
+ metrics = {
+ 'dense': 'mse',
+ 'dropout': metrics_module.CategoricalAccuracy()
+ }
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Invalid use cases
+ with self.assertRaises(ValueError):
+ model.train_on_batch({'input_a': input_a_np},
+ [output_d_np, output_e_np])
+ with self.assertRaises(AttributeError):
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=1,
- batch_size=5,
+ validation_data=([input_a_np, input_b_np], 0, 0),
verbose=0)
+ with self.assertRaises(ValueError):
+ model.train_on_batch([input_a_np], [output_d_np, output_e_np])
+ with self.assertRaises(AttributeError):
+ model.train_on_batch(1, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ model.train_on_batch(input_a_np, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_input = np.random.random((11, 3))
+ model.train_on_batch([bad_input, input_b_np],
+ [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_target = np.random.random((11, 4))
+ model.train_on_batch([input_a_np, input_b_np],
+ [bad_target, output_e_np])
+
+ # Build single-input model
+ x = keras.layers.Input(shape=(3,), name='input_a')
+ y = keras.layers.Dense(4)(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer, loss='mse')
+ # This will work
+ model.fit([input_a_np], output_d_np, epochs=1)
+ with self.assertRaises(ValueError):
+ model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
- # Invalid use cases
- with self.assertRaises(ValueError):
- model.train_on_batch({'input_a': input_a_np},
- [output_d_np, output_e_np])
- with self.assertRaises(AttributeError):
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- validation_data=([input_a_np, input_b_np], 0, 0),
- verbose=0)
- with self.assertRaises(ValueError):
- model.train_on_batch([input_a_np], [output_d_np, output_e_np])
- with self.assertRaises(AttributeError):
- model.train_on_batch(1, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- model.train_on_batch(input_a_np, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_input = np.random.random((11, 3))
- model.train_on_batch([bad_input, input_b_np],
- [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_target = np.random.random((11, 4))
- model.train_on_batch([input_a_np, input_b_np],
- [bad_target, output_e_np])
-
- # Build single-input model
- x = keras.layers.Input(shape=(3,), name='input_a')
- y = keras.layers.Dense(4)(x)
- model = keras.models.Model(x, y)
- model.compile(optimizer, loss='mse')
- # This will work
- model.fit([input_a_np], output_d_np, epochs=1)
- with self.assertRaises(ValueError):
- model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
-
- # Test model on a list of floats
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 4))
+ # Test model on a list of floats
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 4))
- model.fit([np.ndarray.tolist(input_a_np)],
- [np.ndarray.tolist(input_b_np)],
- epochs=2,
- batch_size=5,
- verbose=2)
+ model.fit([np.ndarray.tolist(input_a_np)],
+ [np.ndarray.tolist(input_b_np)],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
@tf_test_util.run_in_graph_and_eager_modes
def test_evaluate_predict_on_arrays(self):
- with self.test_session():
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- model.compile(
- optimizer,
- loss,
- metrics=['mae', metrics_module.CategoricalAccuracy()],
- loss_weights=loss_weights,
- sample_weight_mode=None)
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
+ model = keras.models.Model([a, b], [d, e])
- # Test evaluate at different verbosity
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=0)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=1)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=2)
- self.assertEqual(len(out), 7)
- out = model.test_on_batch([input_a_np, input_b_np],
- [output_d_np, output_e_np])
- self.assertEqual(len(out), 7)
-
- # Test evaluate with dictionary inputs
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- batch_size=5,
- verbose=0)
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- batch_size=5,
- verbose=1)
-
- # Test predict
- out = model.predict([input_a_np, input_b_np], batch_size=5)
- self.assertEqual(len(out), 2)
- out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
- self.assertEqual(len(out), 2)
- out = model.predict_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- })
- self.assertEqual(len(out), 2)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=['mae', metrics_module.CategoricalAccuracy()],
+ loss_weights=loss_weights,
+ sample_weight_mode=None)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test evaluate at different verbosity
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=0)
+ self.assertEqual(len(out), 7)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=1)
+ self.assertEqual(len(out), 7)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=2)
+ self.assertEqual(len(out), 7)
+ out = model.test_on_batch([input_a_np, input_b_np],
+ [output_d_np, output_e_np])
+ self.assertEqual(len(out), 7)
+
+ # Test evaluate with dictionary inputs
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ batch_size=5,
+ verbose=0)
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ batch_size=5,
+ verbose=1)
+
+ # Test predict
+ out = model.predict([input_a_np, input_b_np], batch_size=5)
+ self.assertEqual(len(out), 2)
+ out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
+ self.assertEqual(len(out), 2)
+ out = model.predict_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ })
+ self.assertEqual(len(out), 2)
@tf_test_util.run_in_graph_and_eager_modes
def test_invalid_loss(self):
@@ -340,31 +338,27 @@ class TrainingTest(test.TestCase):
test_samples = 1000
input_dim = 5
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model.compile(optimizer, loss='categorical_crossentropy')
- np.random.seed(1337)
- (x_train, y_train), (_, _) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, loss='categorical_crossentropy')
+ np.random.seed(1337)
+ (x_train, y_train), (_, _) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
- with self.assertRaises(ValueError):
- model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
+ with self.assertRaises(ValueError):
+ model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
- if not context.executing_eagerly():
- # TODO(psv): Investigate these use cases in eager mode.
- with self.assertRaises(ValueError):
- model.fit(x_train, y_train)
+ if not context.executing_eagerly():
+ # TODO(psv): Investigate these use cases in eager mode.
+ with self.assertRaises(ValueError):
+ model.fit(x_train, y_train)
- with self.assertRaises(ValueError):
- model.compile(optimizer, loss=None)
+ with self.assertRaises(ValueError):
+ model.compile(optimizer, loss=None)
def test_training_on_sparse_data_with_dense_placeholders(self):
if scipy_sparse is None:
@@ -468,67 +462,63 @@ class LossWeightingTest(test.TestCase):
input_dim = 5
learning_rate = 0.001
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- model.compile(
- loss='categorical_crossentropy',
- metrics=['acc'],
- weighted_metrics=['mae'],
- optimizer=RMSPropOptimizer(learning_rate=learning_rate))
-
- np.random.seed(1337)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
- int_y_test = y_test.copy()
- int_y_train = y_train.copy()
- # convert class vectors to binary class matrices
- y_train = keras.utils.to_categorical(y_train, num_classes)
- y_test = keras.utils.to_categorical(y_test, num_classes)
- test_ids = np.where(int_y_test == np.array(weighted_class))[0]
-
- class_weight = dict([(i, 1.) for i in range(num_classes)])
- class_weight[weighted_class] = 2.
-
- sample_weight = np.ones((y_train.shape[0]))
- sample_weight[int_y_train == weighted_class] = 2.
-
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
- verbose=0,
- class_weight=class_weight,
- validation_data=(x_train, y_train, sample_weight))
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 2,
- verbose=0,
- class_weight=class_weight)
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 2,
- verbose=0,
- class_weight=class_weight,
- validation_split=0.1)
-
- model.train_on_batch(
- x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
- ref_score = model.evaluate(x_test, y_test, verbose=0)
- score = model.evaluate(
- x_test[test_ids, :], y_test[test_ids, :], verbose=0)
- self.assertLess(score[0], ref_score[0])
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ model.compile(
+ loss='categorical_crossentropy',
+ metrics=['acc'],
+ weighted_metrics=['mae'],
+ optimizer=RMSPropOptimizer(learning_rate=learning_rate))
+
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+ class_weight[weighted_class] = 2.
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ class_weight=class_weight,
+ validation_data=(x_train, y_train, sample_weight))
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight,
+ validation_split=0.1)
+
+ model.train_on_batch(
+ x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
+ ref_score = model.evaluate(x_test, y_test, verbose=0)
+ score = model.evaluate(
+ x_test[test_ids, :], y_test[test_ids, :], verbose=0)
+ self.assertLess(score[0], ref_score[0])
@tf_test_util.run_in_graph_and_eager_modes
def test_sample_weights(self):
@@ -541,63 +531,82 @@ class LossWeightingTest(test.TestCase):
input_dim = 5
learning_rate = 0.001
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- model.compile(
- RMSPropOptimizer(learning_rate=learning_rate),
- metrics=['acc'],
- weighted_metrics=['mae'],
- loss='categorical_crossentropy')
-
- np.random.seed(43)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
- int_y_test = y_test.copy()
- int_y_train = y_train.copy()
- # convert class vectors to binary class matrices
- y_train = keras.utils.to_categorical(y_train, num_classes)
- y_test = keras.utils.to_categorical(y_test, num_classes)
- test_ids = np.where(int_y_test == np.array(weighted_class))[0]
-
- sample_weight = np.ones((y_train.shape[0]))
- sample_weight[int_y_train == weighted_class] = 2.
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ model.compile(
+ RMSPropOptimizer(learning_rate=learning_rate),
+ metrics=['acc'],
+ weighted_metrics=['mae'],
+ loss='categorical_crossentropy')
+
+ np.random.seed(43)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight,
+ validation_split=0.1)
+
+ model.train_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+ model.test_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+ ref_score = model.evaluate(x_test, y_test, verbose=0)
+ if not context.executing_eagerly():
+ score = model.evaluate(
+ x_test[test_ids, :], y_test[test_ids, :], verbose=0)
+ self.assertLess(score[0], ref_score[0])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_warning_for_concurrent_sample_and_class_weights(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(3,)))
+ model.compile(
+ loss='mse',
+ optimizer=RMSPropOptimizer(learning_rate=0.01))
+ x_train = np.random.random((10, 3))
+ y_train = np.random.random((10, 10))
+ sample_weight = np.ones((y_train.shape[0]))
+ class_weight = {0: 1., 1: 1.}
+
+ with test.mock.patch.object(logging, 'warning') as mock_log:
model.fit(
x_train,
y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
- verbose=0,
- sample_weight=sample_weight)
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
+ epochs=1,
verbose=0,
sample_weight=sample_weight,
- validation_split=0.1)
-
- model.train_on_batch(
- x_train[:batch_size],
- y_train[:batch_size],
- sample_weight=sample_weight[:batch_size])
- model.test_on_batch(
- x_train[:batch_size],
- y_train[:batch_size],
- sample_weight=sample_weight[:batch_size])
- ref_score = model.evaluate(x_test, y_test, verbose=0)
- if not context.executing_eagerly():
- score = model.evaluate(
- x_test[test_ids, :], y_test[test_ids, :], verbose=0)
- self.assertLess(score[0], ref_score[0])
+ class_weight=class_weight)
+ msg = ('The `class_weight` argument will be ignored.')
+ self.assertRegexpMatches(str(mock_log.call_args), msg)
@tf_test_util.run_in_graph_and_eager_modes
def test_temporal_sample_weights(self):
@@ -1886,223 +1895,198 @@ class TestTrainingWithDatasetIterators(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_iterators_single_io(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae', metrics_module.CategoricalAccuracy()]
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
-
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
- model.evaluate(iterator, steps=2, verbose=1)
- model.predict(iterator, steps=2)
- model.train_on_batch(iterator)
- model.test_on_batch(iterator)
- model.predict_on_batch(iterator)
-
- # Test with validation data
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(iterator, steps=2, verbose=1)
+ model.predict(iterator, steps=2)
+ model.train_on_batch(iterator)
+ model.test_on_batch(iterator)
+ model.predict_on_batch(iterator)
+
+ # Test with validation data
+ model.fit(iterator,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=iterator, validation_steps=2)
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=iterator, validation_steps=2)
- # Test with validation split
- with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(iterator,
- epochs=1, steps_per_epoch=2, verbose=0,
- validation_split=0.5, validation_steps=2)
-
- # Test with sample weight.
- sample_weight = np.random.random((10,))
- with self.assertRaisesRegexp(
- ValueError, '`sample_weight` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(
- iterator,
- epochs=1,
- steps_per_epoch=2,
- verbose=0,
- sample_weight=sample_weight)
+ validation_split=0.5, validation_steps=2)
- # Test invalid usage
- with self.assertRaisesRegexp(ValueError,
- 'you should not specify a target'):
- model.fit(iterator, iterator,
- epochs=1, steps_per_epoch=2, verbose=0)
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(
+ iterator,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
- with self.assertRaisesRegexp(
- ValueError, 'you should specify the `steps_per_epoch` argument'):
- model.fit(iterator, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.evaluate(iterator, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.predict(iterator, verbose=0)
+ # Test invalid usage
+ with self.assertRaisesRegexp(ValueError,
+ 'you should not specify a target'):
+ model.fit(iterator, iterator,
+ epochs=1, steps_per_epoch=2, verbose=0)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(iterator, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(iterator, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(iterator, verbose=0)
@tf_test_util.run_in_graph_and_eager_modes
def test_get_next_op_created_once(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
-
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
- # Finalize graph to make sure we are not appending another iterator
- # get_next op in the graph.
- ops.get_default_graph().finalize()
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ # Finalize graph to make sure we are not appending another iterator
+ # get_next op in the graph.
+ ops.get_default_graph().finalize()
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
@tf_test_util.run_in_graph_and_eager_modes
def test_iterators_running_out_of_data(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(2)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(2)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
- with test.mock.patch.object(logging, 'warning') as mock_log:
- model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
- self.assertRegexpMatches(
- str(mock_log.call_args),
- 'dataset iterator ran out of data')
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ 'dataset iterator ran out of data')
class TestTrainingWithDataset(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_calling_model_on_same_dataset(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
-
- # Call fit with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
- # Finalize the graph to make sure new ops aren't added when calling on the
- # same dataset
- ops.get_default_graph().finalize()
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ # Call fit with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+ # Finalize the graph to make sure new ops aren't added when calling on the
+ # same dataset
+ ops.get_default_graph().finalize()
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
@tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_dataset(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae', metrics_module.CategoricalAccuracy()]
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
-
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
- model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
- model.train_on_batch(dataset)
- model.predict_on_batch(dataset)
-
- # Test with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
-
- # Test with validation split
- with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(dataset,
- epochs=1, steps_per_epoch=2, verbose=0,
- validation_split=0.5, validation_steps=2)
-
- # Test with sample weight.
- sample_weight = np.random.random((10,))
- with self.assertRaisesRegexp(
- ValueError, '`sample_weight` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(
- dataset,
- epochs=1,
- steps_per_epoch=2,
- verbose=0,
- sample_weight=sample_weight)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
+
+ # Test with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(dataset,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_split=0.5, validation_steps=2)
- # Test invalid usage
- with self.assertRaisesRegexp(ValueError,
- 'you should not specify a target'):
- model.fit(dataset, dataset,
- epochs=1, steps_per_epoch=2, verbose=0)
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
- with self.assertRaisesRegexp(
- ValueError, 'you should specify the `steps_per_epoch` argument'):
- model.fit(dataset, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.evaluate(dataset, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.predict(dataset, verbose=0)
+ # Test invalid usage
+ with self.assertRaisesRegexp(ValueError,
+ 'you should not specify a target'):
+ model.fit(dataset, dataset,
+ epochs=1, steps_per_epoch=2, verbose=0)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(dataset, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(dataset, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(dataset, verbose=0)
def test_dataset_input_shape_validation(self):
with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- model.compile(optimizer, loss)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
# User forgets to batch the dataset
inputs = np.zeros((10, 3))
@@ -2111,7 +2095,7 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset.repeat(100)
with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ r'expected (.*?) to have 2 dimensions'):
model.train_on_batch(dataset)
# Wrong input shape
@@ -2122,7 +2106,7 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset.batch(10)
with self.assertRaisesRegexp(ValueError,
- 'expected input to have shape'):
+ r'expected (.*?) to have shape \(3,\)'):
model.train_on_batch(dataset)
@@ -2153,134 +2137,127 @@ class TestTrainingWithMetrics(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 3, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='mae',
- metrics=['accuracy', metrics_module.BinaryAccuracy()],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- # verify correctness of stateful and stateless metrics.
- x = np.ones((100, 4))
- y = np.ones((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 1.)
- self.assertEqual(outs[2], 1.)
-
- y = np.zeros((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 0.)
- self.assertEqual(outs[2], 0.)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 3, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(
+ 1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='mae',
+ metrics=['accuracy', metrics_module.BinaryAccuracy()],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ # verify correctness of stateful and stateless metrics.
+ x = np.ones((100, 4))
+ y = np.ones((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 1.)
+ self.assertEqual(outs[2], 1.)
+
+ y = np.zeros((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 0.)
+ self.assertEqual(outs[2], 0.)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_iterator(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 8, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='binary_crossentropy',
- metrics=['accuracy', metrics_module.BinaryAccuracy()],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- np.random.seed(123)
- x = np.random.randint(10, size=(100, 4)).astype(np.float32)
- y = np.random.randint(2, size=(100, 1)).astype(np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(np.around(outs[1], decimals=1), 0.5)
- self.assertEqual(np.around(outs[2], decimals=1), 0.5)
-
- y = np.zeros((100, 1), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(outs[1], 0.)
- self.assertEqual(outs[2], 0.)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 8, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(
+ 1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='binary_crossentropy',
+ metrics=['accuracy', metrics_module.BinaryAccuracy()],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ np.random.seed(123)
+ x = np.random.randint(10, size=(100, 4)).astype(np.float32)
+ y = np.random.randint(2, size=(100, 1)).astype(np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(np.around(outs[1], decimals=1), 0.5)
+ self.assertEqual(np.around(outs[2], decimals=1), 0.5)
+
+ y = np.zeros((100, 1), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(outs[1], 0.)
+ self.assertEqual(outs[2], 0.)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_weighted_metrics(self):
- with self.test_session():
- np.random.seed(1337)
- x = np.array([[[1.], [1.]], [[0.], [0.]]])
- model = keras.models.Sequential()
- model.add(
- keras.layers.TimeDistributed(
- keras.layers.Dense(1, kernel_initializer='ones'),
- input_shape=(2, 1)))
- model.compile(
- RMSPropOptimizer(learning_rate=0.001),
- loss='mse',
- sample_weight_mode='temporal',
- weighted_metrics=['accuracy',
- metrics_module.BinaryAccuracy()])
- y = np.array([[[1.], [1.]], [[1.], [1.]]])
-
- outs = model.evaluate(x, y)
- self.assertEqual(outs, [0.5, 0.5, 0.5])
-
- w = np.array([[0., 0.], [0., 0.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertEqual(outs, [0., 0., 0.])
-
- w = np.array([[3., 4.], [1., 2.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001)
+ np.random.seed(1337)
+ x = np.array([[[1.], [1.]], [[0.], [0.]]])
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(1, kernel_initializer='ones'),
+ input_shape=(2, 1)))
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss='mse',
+ sample_weight_mode='temporal',
+ weighted_metrics=['accuracy',
+ metrics_module.BinaryAccuracy()])
+ y = np.array([[[1.], [1.]], [[1.], [1.]]])
+
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs, [0.5, 0.5, 0.5])
+
+ w = np.array([[0., 0.], [0., 0.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertEqual(outs, [0., 0., 0.])
+
+ w = np.array([[3., 4.], [1., 2.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001)
@tf_test_util.run_in_graph_and_eager_modes
def test_metric_state_reset_between_fit_and_evaluate(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
- model.add(keras.layers.Dense(1, activation='sigmoid'))
- acc_obj = metrics_module.BinaryAccuracy()
- model.compile(
- loss='mae',
- metrics=[acc_obj],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- x_train = np.random.random((100, 4))
- y_train = np.random.random((100, 1))
- model.fit(x_train, y_train, batch_size=5, epochs=2)
- self.assertEqual(self.evaluate(acc_obj.count), 100)
-
- x_test = np.random.random((10, 4))
- y_test = np.random.random((10, 1))
- model.evaluate(x_test, y_test, batch_size=5)
- self.assertEqual(self.evaluate(acc_obj.count), 10)
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+ acc_obj = metrics_module.BinaryAccuracy()
+ model.compile(
+ loss='mae',
+ metrics=[acc_obj],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ x_train = np.random.random((100, 4))
+ y_train = np.random.random((100, 1))
+ model.fit(x_train, y_train, batch_size=5, epochs=2)
+ self.assertEqual(self.evaluate(acc_obj.count), 100)
+
+ x_test = np.random.random((10, 4))
+ y_test = np.random.random((10, 1))
+ model.evaluate(x_test, y_test, batch_size=5)
+ self.assertEqual(self.evaluate(acc_obj.count), 10)
@tf_test_util.run_in_graph_and_eager_modes
def test_invalid_metrics(self):
num_classes = 5
input_dim = 5
- with self.test_session():
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(10, activation='relu', input_shape=(input_dim,)))
- model.add(keras.layers.Dense(num_classes, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
- with self.assertRaisesRegexp(
- TypeError, 'Type of `metrics` argument not understood. '
- 'Expected a list or dictionary, found: '):
- model.compile(
- RMSPropOptimizer(learning_rate=0.001),
- loss='categorical_crossentropy',
- metrics=metrics_module.CategoricalAccuracy())
+ with self.assertRaisesRegexp(
+ TypeError, 'Type of `metrics` argument not understood. '
+ 'Expected a list or dictionary, found: '):
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=metrics_module.CategoricalAccuracy())
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_masking(self):
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 0bd6620220..b6aa9adb47 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,13 +20,19 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import generic_utils
-
+from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
# API entries importable from `keras.models`:
Model = training.Model # pylint: disable=invalid-name
@@ -246,3 +252,213 @@ def clone_model(model, input_tensors=None):
return _clone_sequential_model(model, input_tensors=input_tensors)
else:
return _clone_functional_model(model, input_tensors=input_tensors)
+
+
+# "Clone" a subclassed model by reseting all of the attributes.
+
+
+def _in_place_subclassed_model_reset(model):
+ """Substitute for model cloning that works for subclassed models.
+
+ Subclassed models cannot be cloned because their topology is not serializable.
+ To "instantiate" an identical model in a new TF graph, we reuse the original
+ model object, but we clear its state.
+
+ After calling this function on a model instance, you can use the model
+ instance as if it were a model clone (in particular you can use it in a new
+ graph).
+
+ This method clears the state of the input model. It is thus destructive.
+ However the original state can be restored fully by calling
+ `_in_place_subclassed_model_state_restoration`.
+
+ Args:
+ model: Instance of a Keras model created via subclassing.
+
+ Raises:
+ ValueError: In case the model uses a subclassed model as inner layer.
+ """
+ assert not model._is_graph_network # Only makes sense for subclassed networks
+ # Retrieve all layers tracked by the model as well as their attribute names
+ attributes_cache = {}
+ for name in dir(model):
+ try:
+ value = getattr(model, name)
+ except (AttributeError, ValueError, TypeError):
+ continue
+ if isinstance(value, Layer):
+ attributes_cache[name] = value
+ assert value in model._layers
+ elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ # Handle case: list/tuple of layers (also tracked by the Network API).
+ if value and all(isinstance(val, Layer) for val in value):
+ raise ValueError('We do not support the use of list-of-layers '
+ 'attributes in subclassed models used with '
+ '`model_to_estimator` at this time. Found list '
+ 'model: %s' % name)
+
+ # Replace layers on the model with fresh layers
+ layers_to_names = {value: key for key, value in attributes_cache.items()}
+ original_layers = model._layers[:]
+ model._layers = data_structures.NoDependency([])
+ for layer in original_layers: # We preserve layer order.
+ config = layer.get_config()
+ # This will not work for nested subclassed models used as layers.
+ # This would be theoretically possible to support, but would add complexity.
+ # Only do it if users complain.
+ if isinstance(layer, Network) and not layer._is_graph_network:
+ raise ValueError('We do not support the use of nested subclassed models '
+ 'in `model_to_estimator` at this time. Found nested '
+ 'model: %s' % layer)
+ fresh_layer = layer.__class__.from_config(config)
+ name = layers_to_names[layer]
+ setattr(model, name, fresh_layer)
+
+ # Cache original model build attributes (in addition to layers)
+ if (not hasattr(model, '_original_attributes_cache') or
+ model._original_attributes_cache is None):
+ if model.built:
+ attributes_to_cache = [
+ 'inputs',
+ 'outputs',
+ '_feed_outputs',
+ '_feed_output_names',
+ '_feed_output_shapes',
+ '_feed_loss_fns',
+ 'loss_weights_list',
+ 'targets',
+ '_feed_targets',
+ 'sample_weight_modes',
+ 'weighted_metrics',
+ 'metrics_names',
+ 'metrics_tensors',
+ 'metrics_updates',
+ 'stateful_metric_names',
+ 'total_loss',
+ 'sample_weights',
+ '_feed_sample_weights',
+ 'train_function',
+ 'test_function',
+ 'predict_function',
+ '_collected_trainable_weights',
+ '_feed_inputs',
+ '_feed_input_names',
+ '_feed_input_shapes',
+ 'optimizer',
+ ]
+ for name in attributes_to_cache:
+ attributes_cache[name] = getattr(model, name)
+ model._original_attributes_cache = data_structures.NoDependency(
+ attributes_cache)
+ # Reset built state
+ model.built = False
+ model.inputs = None
+ model.outputs = None
+
+
+def in_place_subclassed_model_state_restoration(model):
+ """Restores the original state of a model after it was "reset".
+
+ This undoes this action of `_in_place_subclassed_model_reset`, which is called
+ in `clone_and_build_model` if `in_place_reset` is set to True.
+
+ Args:
+ model: Instance of a Keras model created via subclassing, on which
+ `_in_place_subclassed_model_reset` was previously called.
+ """
+ assert not model._is_graph_network
+ # Restore layers and build attributes
+ if (hasattr(model, '_original_attributes_cache') and
+ model._original_attributes_cache is not None):
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
+ for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
+ setattr(model, name, value)
+ model._original_attributes_cache = None
+ else:
+ # Restore to the state of a never-called model.
+ model.built = False
+ model.inputs = None
+ model.outputs = None
+
+
+def clone_and_build_model(
+ model, input_tensors=None, target_tensors=None, custom_objects=None,
+ compile_clone=True, in_place_reset=False):
+ """Clone a `Model` and build/compile it with the same settings used before.
+
+ This function should be run in the same graph as the model.
+
+ Args:
+ model: `tf.keras.Model` object. Can be Functional, Sequential, or
+ sub-classed.
+ input_tensors: Optional list of input tensors to build the model upon. If
+ not provided, placeholders will be created.
+ target_tensors: Optional list of target tensors for compiling the model. If
+ not provided, placeholders will be created.
+ custom_objects: Optional dictionary mapping string names to custom classes
+ or functions.
+ compile_clone: Boolean, whether to compile model clone (default `True`).
+ in_place_reset: Boolean, whether to reset the model in place. Only used if
+ the model is not a graph network. If the model is a subclassed model, then
+ this argument must be set to `True` (default `False`). To restore the
+ original model, use the function
+ `in_place_subclassed_model_state_restoration(model)`.
+
+ Returns:
+ Clone of the model.
+
+ Raises:
+ ValueError: if trying to clone a subclassed model, and `in_place_reset` is
+ set to False.
+ """
+ if model._is_graph_network:
+ if custom_objects:
+ with CustomObjectScope(custom_objects):
+ clone = clone_model(model, input_tensors=input_tensors)
+ else:
+ clone = clone_model(model, input_tensors=input_tensors)
+ else:
+ if not in_place_reset:
+ raise ValueError(
+ 'Model is not a graph network (usually means that it is a subclassed '
+ 'model). The model cannot be cloned, but there is a workaround where '
+ 'the model is reset in-place. To use this, please set the argument '
+ '`in_place_reset` to `True`. This will reset the attributes in the '
+ 'original model. To restore the attributes, call '
+ '`in_place_subclassed_model_state_restoration(model)`.')
+ clone = model
+ _in_place_subclassed_model_reset(clone)
+ if input_tensors is not None:
+ clone._set_inputs(input_tensors)
+
+ # Compile/Build model
+ if not compile_clone:
+ if isinstance(clone, Sequential):
+ clone.build()
+ elif model.optimizer:
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ optimizer = model.optimizer
+ else:
+ optimizer_config = model.optimizer.get_config()
+ optimizer = model.optimizer.__class__.from_config(optimizer_config)
+ optimizer.iterations = training_util.get_or_create_global_step()
+
+ clone.compile(
+ optimizer,
+ model.loss,
+ metrics=model.metrics,
+ loss_weights=model.loss_weights,
+ sample_weight_mode=model.sample_weight_mode,
+ weighted_metrics=model.weighted_metrics,
+ target_tensors=target_tensors)
+
+ return clone
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 1385ad5390..5f755f7b5e 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -24,6 +24,8 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras import models
from tensorflow.python.platform import test
from tensorflow.python.training import adam
@@ -169,6 +171,7 @@ class CheckpointingTests(test.TestCase):
model.load_weights(save_prefix)
self.assertEqual(12., self.evaluate(beta1_power))
+
class TestModelBackend(test.TestCase):
def test_model_backend_float64_use_cases(self):
@@ -183,5 +186,136 @@ class TestModelBackend(test.TestCase):
keras.backend.set_floatx(floatx)
+
+class TestCloneAndBuildModel(test.TestCase):
+
+ def test_clone_and_build_non_compiled_model(self):
+ with self.test_session():
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(4,)))
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dropout(0.5))
+ model.add(keras.layers.Dense(4))
+
+ # Everything should work in a new session.
+ keras.backend.clear_session()
+
+ with self.test_session():
+ # With placeholder creation
+ new_model = models.clone_and_build_model(model, compile_clone=True)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.evaluate(inp, out)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.train_on_batch(inp, out)
+ new_model.compile('rmsprop', 'mse')
+ new_model.train_on_batch(inp, out)
+
+ # Create new tensors for inputs and targets
+ input_a = keras.Input(shape=(4,))
+ target_a = keras.Input(shape=(4,))
+ new_model = models.clone_and_build_model(model, input_tensors=input_a,
+ target_tensors=[target_a],
+ compile_clone=True)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.evaluate(inp, out)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.train_on_batch(inp, out)
+ new_model.compile('rmsprop', 'mse')
+ new_model.train_on_batch(inp, out)
+
+ def _assert_same_compile_params(self, model):
+ """Assert that two models have the same compile parameters."""
+
+ self.assertEqual('mse', model.loss)
+ self.assertTrue(
+ isinstance(model.optimizer, keras.optimizers.RMSprop))
+ self.assertEqual(['acc', metrics.categorical_accuracy], model.metrics)
+
+ def _clone_and_build_test_helper(self, model, is_subclassed=False):
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+
+ # Everything should work in a new session.
+ keras.backend.clear_session()
+
+ with self.test_session():
+ # With placeholder creation
+ new_model = models.clone_and_build_model(
+ model, compile_clone=True, in_place_reset=is_subclassed)
+
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ # Create new tensors for inputs and targets
+ input_a = keras.Input(shape=(4,), name='a')
+ new_model = models.clone_and_build_model(
+ model, input_tensors=input_a, compile_clone=True,
+ in_place_reset=is_subclassed)
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ target_a = keras.Input(shape=(4,), name='b')
+ new_model = models.clone_and_build_model(
+ model, input_tensors=input_a, target_tensors=[target_a],
+ compile_clone=True, in_place_reset=is_subclassed)
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ def test_clone_and_build_compiled_sequential_model(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(4,)))
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dropout(0.5))
+ model.add(keras.layers.Dense(4))
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ self._clone_and_build_test_helper(model)
+
+ def test_clone_and_build_functional_model(self):
+ with self.test_session():
+ input_a = keras.Input(shape=(4,))
+ dense_1 = keras.layers.Dense(4,)
+ dense_2 = keras.layers.Dense(4,)
+
+ x_a = dense_1(input_a)
+ x_a = keras.layers.Dropout(0.5)(x_a)
+ x_a = keras.layers.BatchNormalization()(x_a)
+ x_a = dense_2(x_a)
+ model = keras.models.Model(input_a, x_a)
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ self._clone_and_build_test_helper(model)
+
+ def test_clone_and_build_subclassed_model(self):
+ class SubclassedModel(keras.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.layer1 = keras.layers.Dense(4)
+ self.layer2 = keras.layers.Dense(4)
+
+ def call(self, inp):
+ out = self.layer1(inp)
+ out = keras.layers.BatchNormalization()(out)
+ out = keras.layers.Dropout(0.5)(out)
+ out = self.layer2(out)
+ return out
+
+ with self.test_session():
+ model = SubclassedModel()
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+ self._clone_and_build_test_helper(model, True)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 6e8ee06ff5..58405c550b 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -184,3 +184,22 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
# for further checks in the caller function
return actual_output
+
+def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
+ model = keras.models.Sequential()
+ if input_dim:
+ model.add(keras.layers.Dense(num_hidden, activation='relu',
+ input_dim=input_dim))
+ else:
+ model.add(keras.layers.Dense(num_hidden, activation='relu'))
+ activation = 'sigmoid' if num_classes == 1 else 'softmax'
+ model.add(keras.layers.Dense(num_classes, activation=activation))
+ return model
+
+
+def get_small_functional_mlp(num_hidden, num_classes, input_dim):
+ inputs = keras.Input(shape=(input_dim,))
+ outputs = keras.layers.Dense(num_hidden, activation='relu')(inputs)
+ activation = 'sigmoid' if num_classes == 1 else 'softmax'
+ outputs = keras.layers.Dense(num_classes, activation=activation)(outputs)
+ return keras.Model(inputs, outputs)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 9fe52f3d28..b9c5f26cb7 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -84,6 +84,25 @@ tf_py_test(
)
tf_py_test(
+ name = "batch_scatter_ops_test",
+ srcs = ["batch_scatter_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+tf_py_test(
name = "bcast_ops_test",
size = "small",
srcs = ["bcast_ops_test.py"],
@@ -645,7 +664,7 @@ cuda_py_test(
cuda_py_test(
name = "parameterized_truncated_normal_op_test",
- size = "small",
+ size = "medium",
srcs = ["parameterized_truncated_normal_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -747,6 +766,7 @@ tf_py_test(
size = "small",
srcs = ["regex_replace_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -960,6 +980,17 @@ tf_py_test(
)
tf_py_test(
+ name = "string_length_op_test",
+ size = "small",
+ srcs = ["string_length_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "string_strip_op_test",
size = "small",
srcs = ["string_strip_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/batch_scatter_ops_test.py b/tensorflow/python/kernel_tests/batch_scatter_ops_test.py
new file mode 100644
index 0000000000..0d41a7e3b3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/batch_scatter_ops_test.py
@@ -0,0 +1,129 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.scatter."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _AsType(v, vtype):
+ return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
+
+
+def _NumpyUpdate(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ indx = i[:-1] + (indx,)
+ ref[indx] = updates[i]
+
+
+_TF_OPS_TO_NUMPY = {
+ state_ops.batch_scatter_update: _NumpyUpdate,
+}
+
+
+class ScatterTest(test.TestCase):
+
+ def _VariableRankTest(self,
+ tf_scatter,
+ vtype,
+ itype,
+ repeat_indices=False,
+ updates_are_scalar=False):
+ np.random.seed(8)
+ with self.test_session(use_gpu=False):
+ for indices_shape in (2,), (3, 7), (3, 4, 7):
+ for extra_shape in (), (5,), (5, 9):
+ # Generate random indices with no duplicates for easy numpy comparison
+ sparse_dim = len(indices_shape) - 1
+ indices = np.random.randint(
+ indices_shape[sparse_dim], size=indices_shape, dtype=itype)
+ updates = _AsType(
+ np.random.randn(*(indices_shape + extra_shape)), vtype)
+
+ old = _AsType(np.random.randn(*(indices_shape + extra_shape)), vtype)
+
+ # Scatter via numpy
+ new = old.copy()
+ np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
+ np_scatter(new, indices, updates)
+ # Scatter via tensorflow
+ ref = variables.Variable(old)
+ ref.initializer.run()
+ tf_scatter(ref, indices, updates).eval()
+ self.assertAllClose(ref.eval(), new)
+
+ def _VariableRankTests(self,
+ tf_scatter):
+ vtypes = [np.float32, np.float64]
+ if tf_scatter != state_ops.scatter_div:
+ vtypes.append(np.int32)
+
+ for vtype in vtypes:
+ for itype in (np.int32, np.int64):
+ self._VariableRankTest(tf_scatter, vtype, itype)
+
+ def testVariableRankUpdate(self):
+ vtypes = [np.float32, np.float64]
+ for vtype in vtypes:
+ for itype in (np.int32, np.int64):
+ self._VariableRankTest(
+ state_ops.batch_scatter_update, vtype, itype)
+
+ def testBooleanScatterUpdate(self):
+ with self.test_session(use_gpu=False) as session:
+ var = variables.Variable([True, False])
+ update0 = state_ops.batch_scatter_update(var, [1], [True])
+ update1 = state_ops.batch_scatter_update(
+ var, constant_op.constant(
+ [0], dtype=dtypes.int64), [False])
+ var.initializer.run()
+
+ session.run([update0, update1])
+
+ self.assertAllEqual([False, True], var.eval())
+
+ def testScatterOutOfRange(self):
+ params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
+ updates = np.array([-3, -4, -5]).astype(np.float32)
+ with self.test_session(use_gpu=False):
+ ref = variables.Variable(params)
+ ref.initializer.run()
+
+ # Indices all in range, no problem.
+ indices = np.array([2, 0, 5])
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+ # Test some out of range errors.
+ indices = np.array([-1, 0, 5])
+ with self.assertRaisesOpError(
+ r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+ indices = np.array([2, 0, 6])
+ with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
+ r'shape \[6\]'):
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index bf82e08551..3193222262 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -421,6 +421,31 @@ class ListOpsTest(test_util.TensorFlowTestCase):
"Invalid data type at index 0"):
self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
+ @test_util.run_in_graph_and_eager_modes
+ def testZerosLike(self):
+ for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+ dtypes.float64, dtypes.complex64, dtypes.complex128,
+ dtypes.bool):
+ l_empty = list_ops.empty_tensor_list(
+ element_dtype=dtype, element_shape=scalar_shape())
+ l_empty_zeros = array_ops.zeros_like(l_empty)
+ t_empty_zeros = list_ops.tensor_list_stack(
+ l_empty_zeros, element_dtype=dtype)
+
+ l_full = list_ops.tensor_list_push_back(l_empty,
+ math_ops.cast(0, dtype=dtype))
+ l_full = list_ops.tensor_list_push_back(l_full,
+ math_ops.cast(1, dtype=dtype))
+ l_full_zeros = array_ops.zeros_like(l_full)
+ t_full_zeros = list_ops.tensor_list_stack(
+ l_full_zeros, element_dtype=dtype)
+
+ self.assertAllEqual(self.evaluate(t_empty_zeros), [])
+ self.assertAllEqual(
+ self.evaluate(t_full_zeros), np.zeros(
+ (2,), dtype=dtype.as_numpy_dtype))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
index dd67919f69..e14894cf56 100644
--- a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
+++ b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
@@ -182,6 +182,19 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
def testSmallStddev(self):
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
+ def testSamplingWithSmallStdDevFarFromBound(self):
+ sample_op = random_ops.parameterized_truncated_normal(
+ shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
+
+ with self.test_session(use_gpu=True) as sess:
+ samples = sess.run(sample_op)
+ # 0. is more than 16 standard deviations from the mean, and
+ # should have a likelihood < 1e-57.
+ # TODO(jjhunt) Sampler is still numerically unstable in this case,
+ # numbers less than 0 should never observed.
+ no_neg_samples = np.sum(samples < 0.)
+ self.assertLess(no_neg_samples, 2.)
+
# Benchmarking code
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index ba9359d923..1d0c2dceba 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -34,6 +36,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
@@ -622,6 +625,38 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose([-0.4, -0.4], x.eval())
+ def testMetaGraphSaveLoad(self):
+ save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
+ with variable_scope.variable_scope("root", partitioner=partitioner):
+ v0 = variable_scope.get_variable(
+ "v0", dtype=dtypes.float32, shape=(10, 10))
+ v0_list = v0._get_variable_list()
+ v0_part = v0._get_partitions()
+ self.assertEqual(len(v0_list), 5)
+ self.assertAllEqual(v0_part, (5, 1))
+ variables.global_variables_initializer().run()
+
+ save_graph.get_collection_ref("partvar").append(v0)
+ saver = saver_lib.Saver()
+ save_graph.finalize()
+ save_path = saver.save(sess=session, save_path=save_prefix)
+ previous_value = session.run(
+ save_graph.get_tensor_by_name(v0.name + ":0"))
+
+ restore_graph = ops.Graph()
+ with restore_graph.as_default(), self.test_session(
+ graph=restore_graph) as session:
+ saver = saver_lib.import_meta_graph(save_path + ".meta")
+ saver.restore(sess=session, save_path=save_path)
+ v0, = save_graph.get_collection_ref("partvar")
+ self.assertIsInstance(v0, variables.PartitionedVariable)
+ self.assertAllEqual(
+ previous_value,
+ session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index 6739ac3224..f0e84b8fca 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -18,54 +18,104 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class RegexReplaceOpTest(test.TestCase):
+@parameterized.parameters(
+ (gen_string_ops.regex_replace),
+ (gen_string_ops.static_regex_replace))
+class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
+
+ def testForwarding(self, op):
+ with self.test_session():
+ # Generate an input that is uniquely consumed by the regex op.
+ # This exercises code paths which are optimized for this case
+ # (e.g., using forwarding).
+ inp = string_ops.substr(
+ constant_op.constant(["AbCdEfG",
+ "HiJkLmN"], dtypes.string),
+ pos=0,
+ len=5)
+ stripped = op(inp, "\\p{Ll}", ".").eval()
+ self.assertAllEqual([b"A.C.E", b"H.J.L"], stripped)
- def testRemovePrefix(self):
+ def testRemovePrefix(self, op):
values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(
- input_vector, "^(a:|b:)", "", replace_global=False).eval()
+ stripped = op(input_vector, "^(a:|b:)", "", replace_global=False).eval()
self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
stripped)
- def testRegexReplace(self):
+ def testRegexReplace(self, op):
values = ["aba\naba", "abcdabcde"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval()
+ stripped = op(input_vector, "a.*a", "(\\0)").eval()
self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
- def testEmptyMatch(self):
+ def testEmptyMatch(self, op):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "", "x").eval()
+ stripped = op(input_vector, "", "x").eval()
self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
- def testInvalidPattern(self):
+ def testInvalidPattern(self, op):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
- replace = string_ops.regex_replace(input_vector, invalid_pattern, "x")
+ replace = op(input_vector, invalid_pattern, "x")
with self.assertRaisesOpError("Invalid pattern"):
replace.eval()
- def testGlobal(self):
+ def testGlobal(self, op):
values = ["ababababab", "abcabcabc", ""]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "ab", "abc",
- True).eval()
+ stripped = op(input_vector, "ab", "abc", True).eval()
self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
+def as_string(s):
+ return s
+
+
+def as_tensor(s):
+ return constant_op.constant(s, dtypes.string)
+
+
+class RegexReplaceTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ (as_string, as_tensor),
+ (as_tensor, as_string),
+ (as_tensor, as_tensor))
+ def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
+ with compat.forward_compatibility_horizon(2018, 10, 11):
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = pattern_fn("[a-z]")
+ replace = rewrite_fn(".")
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("RegexReplace"))
+
+ def testStaticRegexReplaceDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 10, 11):
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ replace = "."
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("StaticRegexReplace"))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index b1ef46f2a1..f815348b2a 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -106,6 +106,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(False, name="bool_test")
self.assertAllEqual(bool(v), False)
+ @test_util.run_in_graph_and_eager_modes
+ def testStridedSliceAssign(self):
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(v[0].assign(2.0))
+ self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
+
def testDifferentAssignGraph(self):
with ops.Graph().as_default():
v = resource_variable_ops.ResourceVariable(1.0)
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index b8e7c50a37..c0269db9ae 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -121,9 +122,12 @@ class SoftplusTest(test.TestCase):
print("softplus (float) third-order gradient err = ", err)
self.assertLess(err, 5e-5)
- def testWarnInts(self):
- # Running the op triggers address sanitizer errors, so we just make it
- nn_ops.softplus(constant_op.constant(7))
+ def testNoInts(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "No OpKernel was registered to support Op 'Softplus'"):
+ nn_ops.softplus(constant_op.constant(7)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index 371f86ff15..a5247ce08d 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -65,11 +66,12 @@ class SoftsignTest(test.TestCase):
print("softsign (float) gradient err = ", err)
self.assertLess(err, 1e-4)
- def testWarnInts(self):
- # NOTE(irving): Actually I don't know how to intercept the warning, but
- # let's make sure it runs. I promised I've looked, and there was a warning.
+ def testNoInts(self):
with self.test_session():
- nn_ops.softsign(constant_op.constant(7)).eval()
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "No OpKernel was registered to support Op 'Softsign'"):
+ nn_ops.softsign(constant_op.constant(7)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
new file mode 100644
index 0000000000..075a3204ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -0,0 +1,37 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for string_length_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class StringLengthOpTest(test.TestCase):
+
+ def testStringLength(self):
+ strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
+
+ with self.test_session() as sess:
+ lengths = string_ops.string_length(strings)
+ values = sess.run(lengths)
+ self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index e20daccb28..b6a0f45adc 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -58,14 +58,28 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(shape, [3, 5])
def testStringSplitEmptyToken(self):
- strings = [" hello ", "", "world "]
+ strings = ["", " a", "b ", " c", " ", " d ", " e", "f ", " g ", " "]
with self.test_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
- self.assertAllEqual(indices, [[0, 0], [2, 0]])
- self.assertAllEqual(values, [b"hello", b"world"])
- self.assertAllEqual(shape, [3, 1])
+ self.assertAllEqual(
+ indices,
+ [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+ self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+ self.assertAllEqual(shape, [10, 1])
+
+ def testStringSplitOnSetEmptyToken(self):
+ strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split(strings, delimiter=" .")
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(
+ indices,
+ [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+ self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+ self.assertAllEqual(shape, [10, 1])
def testStringSplitWithDelimiter(self):
strings = ["hello|world", "hello world"]
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 0b3a396d6b..9dcdaa61ed 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -359,6 +360,23 @@ class TemplateTest(test.TestCase):
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
+ model = training.Model()
+ model.template = tmpl1
+ self.assertEqual(model.variables, [v1, v2])
+ self.assertEqual(model.trainable_variables, [v1, v2])
+ self.assertEqual(len(model.non_trainable_variables), 0)
+ model.templates = [tmpl2]
+ self.assertEqual(model.variables, [v1, v2, v5, v6])
+ self.assertEqual(model.trainable_variables, [v1, v2, v5, v6])
+ self.assertEqual(len(model.non_trainable_variables), 0)
+ # Make sure losses, layers, and updates aren't broken by having a Template
+ # in the mix, which does not expose any updates or losses.
+ self.assertEqual([], model.layers)
+ self.assertEqual([], model.updates)
+ self.assertEqual([], model.losses)
+ self.assertEqual([], model.templates.layers)
+ self.assertEqual([], model.templates.updates)
+ self.assertEqual([], model.templates.losses)
@test_util.run_in_graph_and_eager_modes
def test_nested_templates_with_defun(self):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index ab08865532..3ba880d7a1 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -262,11 +262,13 @@ class Layer(base_layer.Layer):
use_resource = (use_resource or
self._use_resource_variables or
scope.use_resource)
+ if initializer is None:
+ initializer = scope.initializer
variable = super(Layer, self).add_weight(
name,
shape,
dtype=dtypes.as_dtype(dtype),
- initializer=initializer or scope.initializer,
+ initializer=initializer,
trainable=trainable,
constraint=constraint,
partitioner=partitioner,
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index fc02d6de0e..6189503d8f 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -398,7 +398,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
TF_RETURN_IF_ERROR(NumericNpDTypeToTfDType(PyArray_TYPE(input), &dtype));
CHECK(DataTypeCanUseMemcpy(dtype));
if (reinterpret_cast<intptr_t>(PyArray_DATA(input)) %
- EIGEN_MAX_ALIGN_BYTES !=
+ std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
0) {
Tensor t(dtype, shape);
StringPiece p = t.tensor_data();
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index 3c64813735..e4e5268b0f 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -52,10 +52,17 @@ PyRecordWriter::~PyRecordWriter() {
file_.reset();
}
-bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
- if (writer_ == nullptr) return false;
+void PyRecordWriter::WriteRecord(tensorflow::StringPiece record,
+ TF_Status* out_status) {
+ if (writer_ == nullptr) {
+ TF_SetStatus(out_status, TF_FAILED_PRECONDITION,
+ "Writer not initialized or previously closed");
+ return;
+ }
Status s = writer_->WriteRecord(record);
- return s.ok();
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ }
}
void PyRecordWriter::Flush(TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index 9d66c031d4..61a4960ee6 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -43,7 +43,7 @@ class PyRecordWriter {
TF_Status* out_status);
~PyRecordWriter();
- bool WriteRecord(tensorflow::StringPiece record);
+ void WriteRecord(tensorflow::StringPiece record, TF_Status* out_status);
void Flush(TF_Status* out_status);
void Close(TF_Status* out_status);
diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py
index aec12ab3ea..404423ce07 100644
--- a/tensorflow/python/lib/io/python_io.py
+++ b/tensorflow/python/lib/io/python_io.py
@@ -15,7 +15,7 @@
"""Python functions for directly manipulating TFRecord-formatted files.
-See the @{$python/python_io} guide.
+See the [Python IO](https://tensorflow.org/api_guides/python/python_io) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 941d6cd67c..2b3e986f6b 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -125,8 +125,8 @@ class TFRecordWriter(object):
Args:
record: str
"""
- # TODO(sethtroisi): Failures are currently swallowed, change that.
- self._writer.WriteRecord(record)
+ with errors.raise_exception_on_not_ok_status() as status:
+ self._writer.WriteRecord(record, status)
def flush(self):
"""Flush the file."""
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index 4743c037ec..b853b64ae4 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -358,12 +358,12 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase):
with self.assertRaises(errors_impl.FailedPreconditionError):
self._writer.flush()
- def testWriteAfterClose(self):
+ def testWriteAfterCloseIsError(self):
self._writer.write(self._Record(0))
self._writer.close()
- # TODO(sethtroisi): No way to know this failed, changed that.
- self._writer.write(self._Record(1))
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ self._writer.write(self._Record(1))
class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests):
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 4b096cb73d..66bc4df18c 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -15,7 +15,7 @@
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Support for manipulating tensors.
-See the @{$python/array_ops} guide.
+See the [Array Ops](https://tensorflow.org/api_guides/python/array_ops) guide.
"""
from __future__ import absolute_import
@@ -712,10 +712,7 @@ def strided_slice(input_,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask)
- if not context.executing_eagerly():
- # TODO(apassos) In eager mode assignment will be done by overriding
- # __setitem__ instead.
- op.assign = assign
+ op.assign = assign
return op
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 375a5ec2c3..c5a0f2949e 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -15,7 +15,8 @@
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks.
-See the @{$python/check_ops} guide.
+See the [Asserts and
+checks](https://tensorflow.org/api_guides/python/check_ops) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f84ff4ddf0..d1095c8954 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Control Flow Operations.
-See the @{$python/control_flow_ops} guide.
+See the [Control
+Flow](https://tensorflow.org/api_guides/python/control_flow_ops) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 4ecc74675a..a6be82673f 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -15,7 +15,8 @@
"""Functional operations.
-See the @{$python/functional_ops} guide.
+See the [Higher Order
+Functions](https://tensorflow.org/api_guides/python/functional_ops) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index e86a8e5a5b..7291e05685 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -14,8 +14,6 @@
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Histograms.
-
-Please see @{$python/histogram_ops} guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 343531ac55..3de46e7cf3 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -16,7 +16,7 @@
# pylint: disable=g-short-docstring-punctuation
"""Image processing and decoding ops.
-See the @{$python/image} guide.
+See the [Images](https://tensorflow.org/api_guides/python/image) guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index b5274ef2ed..fbc1350c61 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -16,7 +16,8 @@
# pylint: disable=line-too-long
"""Inputs and Readers.
-See the @{$python/io_ops} guide.
+See the [Inputs and
+Readers](https://tensorflow.org/api_guides/python/io_ops) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 2a7a2fd51f..8e11c4bce1 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,9 +972,9 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
-@ops.RegisterGradient("UnsafeDiv")
-def _UnsafeDivGrad(op, grad):
- """UnsafeDiv op gradient."""
+@ops.RegisterGradient("DivNoNan")
+def _DivNoNanGrad(op, grad):
+ """DivNoNan op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
@@ -983,10 +983,10 @@ def _UnsafeDivGrad(op, grad):
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
- math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+ math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
- grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+ grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
ry), sy))
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index f9bb60e7fe..059c8ebd7e 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -231,11 +231,12 @@ class FloorModGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
-class UnsafeDivGradientTest(test.TestCase):
+class DivNoNanGradientTest(test.TestCase):
def testBasicGradient(self):
- inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
+ inputs = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(inputs, 1 + math_ops.abs(inputs))
with self.test_session():
error = gradient_checker.compute_gradient_error(
inputs,
@@ -244,9 +245,11 @@ class UnsafeDivGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
def testGradientWithDenominatorIsZero(self):
- x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- y = array_ops.zeros_like(x, dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(x, y)
+ x = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ y = array_ops.zeros_like(x,
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(x, y)
with self.test_session():
dx, dy = gradients.gradients(outputs, [x, y])
self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c9da1a0bba..67ea534639 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Basic arithmetic operators.
-See the @{$python/math_ops} guide.
+See the [python/math_ops](python/math_ops) guide.
"""
from __future__ import absolute_import
from __future__ import division
@@ -1038,29 +1038,27 @@ def div(x, y, name=None):
return _div_python2(x, y, name)
-def unsafe_div(x, y, name=None):
+@tf_export("div_no_nan")
+def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
- Note that the function uses Python 3 division operator semantics.
-
Args:
- x: A `Tensor`. Must be one of the following types:
- `float32`, `float64`, `int16`, `int32`, `int64`.
+ x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
y: A `Tensor` whose dtype is compatible with `x`.
name: A name for the operation (optional).
Returns:
The element-wise value of the x divided by y.
"""
- with ops.name_scope(name, "unsafe_div", [x, y]) as name:
+ with ops.name_scope(name, "div_no_nan", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
x_dtype = x.dtype.base_dtype
y_dtype = y.dtype.base_dtype
if x_dtype != y_dtype:
- raise TypeError(
- "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
- return gen_math_ops.unsafe_div(x, y, name=name)
+ raise TypeError("x and y must have the same dtype, got %r != %r" %
+ (x_dtype, y_dtype))
+ return gen_math_ops.div_no_nan(x, y, name=name)
# TODO(aselle): This should be removed
@@ -2564,8 +2562,9 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
r""" Computes the mean along segments of a tensor.
- Read @{$math_ops#segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
@@ -2596,8 +2595,9 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
- Read @{$math_ops#segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
@@ -2632,8 +2632,9 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor.
- Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -2707,8 +2708,9 @@ def sparse_segment_mean(data,
num_segments=None):
r"""Computes the mean along sparse segments of a tensor.
- Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 5fe7bbca11..5ac7e133d9 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -473,18 +473,19 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, expanded_nums)
-class UnsafeDivTest(test_util.TensorFlowTestCase):
+class DivNoNanTest(test_util.TensorFlowTestCase):
def testBasic(self):
- nums = np.arange(-10, 10, .25).reshape(80, 1)
- divs = np.arange(-3, 3, .25).reshape(1, 24)
+ for dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
- np_result = np.true_divide(nums, divs)
- np_result[:, divs[0] == 0] = 0
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
- with self.test_session():
- tf_result = math_ops.unsafe_div(nums, divs).eval()
- self.assertAllEqual(tf_result, np_result)
+ with self.test_session():
+ tf_result = math_ops.div_no_nan(nums, divs).eval()
+ self.assertAllEqual(tf_result, np_result)
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9461a01515..763877c2d2 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -301,6 +301,40 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
return total_cm, update_op
+def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
+ """Aggregate metric value across towers."""
+ def fn(distribution, *a):
+ """Call `metric_value_fn` in the correct control flow context."""
+ if hasattr(distribution, '_outer_control_flow_context'):
+ # If there was an outer context captured before this method was called,
+ # then we enter that context to create the metric value op. If the
+ # caputred context is `None`, ops.control_dependencies(None) gives the
+ # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
+ # captured context.
+ # This special handling is needed because sometimes the metric is created
+ # inside a while_loop (and perhaps a TPU rewrite context). But we don't
+ # want the value op to be evaluated every step or on the TPU. So we
+ # create it outside so that it can be evaluated at the end on the host,
+ # once the update ops have been evaluted.
+
+ # pylint: disable=protected-access
+ if distribution._outer_control_flow_context is None:
+ with ops.control_dependencies(None):
+ metric_value = metric_value_fn(distribution, *a)
+ else:
+ distribution._outer_control_flow_context.Enter()
+ metric_value = metric_value_fn(distribution, *a)
+ distribution._outer_control_flow_context.Exit()
+ # pylint: enable=protected-access
+ else:
+ metric_value = metric_value_fn(distribution, *a)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric_value)
+ return metric_value
+
+ return distribution_strategy_context.get_tower_context().merge_call(fn, *args)
+
+
@tf_export('metrics.mean')
def mean(values,
weights=None,
@@ -368,14 +402,10 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
@@ -612,14 +642,8 @@ def _confusion_matrix_at_thresholds(labels,
def _aggregate_variable(v, collections):
-
- def f(distribution, value):
- value = distribution.read_var(value)
- if collections:
- ops.add_to_collections(collections, value)
- return value
-
- return distribution_strategy_context.get_tower_context().merge_call(f, v)
+ f = lambda distribution, value: distribution.read_var(value)
+ return _aggregate_across_towers(collections, f, v)
@tf_export('metrics.auc')
@@ -807,15 +831,12 @@ def auc(labels,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- def aggregate_auc(_, values):
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
- return auc_value
-
- auc_value = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_auc, values)
+ def compute_auc_value(_, values):
+ return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
+ 'value')
+
+ auc_value = _aggregate_across_towers(
+ metrics_collections, compute_auc_value, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1046,16 +1067,14 @@ def mean_per_class_accuracy(labels,
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- def aggregate_mean_accuracy(_, count, total):
+ def compute_mean_accuracy(_, count, total):
per_class_accuracy = _safe_div(count, total, None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1128,7 +1147,7 @@ def mean_iou(labels,
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(total_cm, name):
+ def compute_mean_iou(_, total_cm):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1152,17 +1171,12 @@ def mean_iou(labels,
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
math_ops.greater(num_valid_entries, 0),
- math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
+ math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
return result
- def mean_iou_across_towers(_, v):
- mean_iou_v = compute_mean_iou(v, 'mean_iou')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
- return mean_iou_v
-
- mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
- mean_iou_across_towers, total_cm)
+ # TODO(priyag): Use outside_compilation if in TPU context.
+ mean_iou_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_iou, total_cm)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1371,14 +1385,10 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
@@ -2004,13 +2014,10 @@ def precision(labels,
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
def once_across_towers(_, true_p, false_p):
- p = compute_precision(true_p, false_p, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
- return p
+ return compute_precision(true_p, false_p, 'value')
- p = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_p)
+ p = _aggregate_across_towers(metrics_collections, once_across_towers,
+ true_p, false_p)
update_op = compute_precision(true_positives_update_op,
false_positives_update_op, 'update_op')
@@ -2088,13 +2095,10 @@ def precision_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
def precision_across_towers(_, values):
- prec = compute_precision(values['tp'], values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
- return prec
+ return compute_precision(values['tp'], values['fp'], 'value')
- prec = distribution_strategy_context.get_tower_context().merge_call(
- precision_across_towers, values)
+ prec = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
'update_op')
@@ -2184,13 +2188,10 @@ def recall(labels,
math_ops.div(true_p, true_p + false_n), 0, name)
def once_across_towers(_, true_p, false_n):
- rec = compute_recall(true_p, false_n, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(true_p, false_n, 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_n)
+ rec = _aggregate_across_towers(
+ metrics_collections, once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
false_negatives_update_op, 'update_op')
@@ -2622,14 +2623,11 @@ def recall_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fn):
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def compute_recall(_, tp, fn):
+ return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fn)
+ metric = _aggregate_across_towers(
+ metrics_collections, compute_recall, tp, fn)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
@@ -2704,13 +2702,10 @@ def recall_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
def recall_across_towers(_, values):
- rec = compute_recall(values['tp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(values['tp'], values['fn'], 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- recall_across_towers, values)
+ rec = _aggregate_across_towers(
+ metrics_collections, recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
@@ -2778,14 +2773,9 @@ def root_mean_squared_error(labels,
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
- def once_across_towers(_, mse):
- rmse = math_ops.sqrt(mse)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
- return rmse
- rmse = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, mse)
+ once_across_towers = lambda _, mse: math_ops.sqrt(mse)
+ rmse = _aggregate_across_towers(metrics_collections, once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
if updates_collections:
@@ -2880,15 +2870,12 @@ def sensitivity_at_specificity(labels,
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- sensitivity = compute_sensitivity_at_specificity(
+ def sensitivity_across_towers(_, values):
+ return compute_sensitivity_at_specificity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
- return sensitivity
- sensitivity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ sensitivity = _aggregate_across_towers(
+ metrics_collections, sensitivity_across_towers, values)
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
@@ -3157,14 +3144,11 @@ def _streaming_sparse_average_precision_at_top_k(labels,
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- def aggregate_across_towers(_, total_var, max_var):
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
- return mean_average_precision
+ def precision_across_towers(_, total_var, max_var):
+ return _safe_scalar_div(total_var, max_var, name='mean')
- mean_average_precision = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_across_towers, total_var, max_var)
+ mean_average_precision = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3443,14 +3427,11 @@ def precision_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fp):
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def precision_across_towers(_, tp, fp):
+ return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fp)
+ metric = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, tp, fp)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
@@ -3681,15 +3662,12 @@ def specificity_at_sensitivity(labels,
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- specificity = compute_specificity_at_sensitivity(
+ def specificity_across_towers(_, values):
+ return compute_specificity_at_sensitivity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
- return specificity
- specificity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ specificity = _aggregate_across_towers(
+ metrics_collections, specificity_across_towers, values)
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 339684122e..4b73fc830e 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -16,7 +16,7 @@
# pylint: disable=unused-import,g-bad-import-order
"""Neural network support.
-See the @{$python/nn} guide.
+See the [Neural network](https://tensorflow.org/api_guides/python/nn) guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index df23ac55ce..a648653909 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -471,7 +471,9 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax = nn_ops.softmax(logits)
grad += ((grad_grad - array_ops.squeeze(
- math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) *
+ math_ops.matmul(array_ops.expand_dims(grad_grad, 1),
+ array_ops.expand_dims(softmax, 2)),
+ axis=1)) *
softmax)
return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 51f812b395..2a1919e66f 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1210,7 +1210,9 @@ def nce_loss(weights,
num_true]`. The target classes.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
- num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_sampled: An `int`. The number of negative classes to randomly sample
+ per batch. This single sample of negative classes is evaluated for each
+ element in the batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 6fd1273687..edc6e04b48 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -698,7 +698,7 @@ def convolution(
`padded_input` is obtained by zero padding the input using an effective
spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
output striding `strides` as described in the
- @{$python/nn#Convolution$comment here}.
+ [comment here](https://tensorflow.org/api_guides/python/nn#Convolution).
In the case that `data_format` does start with `"NC"`, the `input` and output
(but not the `filter`) are simply transposed as follows:
@@ -1836,8 +1836,9 @@ def softmax_cross_entropy_with_logits_v2(
name: A name for the operation (optional).
Returns:
- A `Tensor` of the same shape as `labels` and of the same type as `logits`
- with the softmax cross entropy loss.
+ A `Tensor` that contains the softmax cross entropy loss. Its type is the
+ same as `logits` and its shape is the same as `labels` except that it does
+ not have the last dimension of `labels`.
"""
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
logits)
@@ -1962,8 +1963,9 @@ def softmax_cross_entropy_with_logits(
name: A name for the operation (optional).
Returns:
- A `Tensor` of the same shape as `labels` and of the same type as `logits`
- with the softmax cross entropy loss.
+ A `Tensor` that contains the softmax cross entropy loss. Its type is the
+ same as `logits` and its shape is the same as `labels` except that it does
+ not have the last dimension of `labels`.
"""
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
logits)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index d533731c07..3d0205f768 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -94,26 +94,8 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
ops.set_shape_and_handle_data_for_outputs(h.op)
handle._handle_data = h._handle_data
# pylint: enable=protected-access
-
- # Clean up our reference cycles to avoid making the garbage collector run.
- # pylint: disable=protected-access
- # OrderedDict, constructed on Graph creation, makes a simple reference loop
- # and hides it in an __attribute in some Python versions. We don't need to
- # throw an error if we can't find it, but if we do find it we can break the
- # loop to avoid creating work for the garbage collector.
- problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
- # pylint: enable=protected-access
- if problematic_cycle:
- try:
- del problematic_cycle[0][:]
- except TypeError:
- # This is probably not one of the problematic Python versions. Continue
- # with the rest of our cleanup.
- pass
- # Now clean up our own reference cycles by clearing all of the attributes for
- # the Graph and op we created.
- h.__dict__ = {}
- graph.__dict__ = {}
+ # Clean up op->graph->op reference cycles.
+ ops.dismantle_graph(graph)
return handle
@@ -185,7 +167,8 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None):
class ResourceVariable(variables.RefVariable):
"""Variable based on resource handles.
- See the @{$variables$Variables How To} for a high level overview.
+ See the [Variables How To](https://tensorflow.org/guide/variables)
+ for a high level overview.
A `ResourceVariable` allows you to maintain state across subsequent calls to
session.run.
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index d11e446dbf..8d66de6b20 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Script Language Operators. See the @{$python/script_ops} guide."""
+"""Script Language Operators."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -343,7 +343,8 @@ def eager_py_func(func, inp, Tout, name=None):
or print statements as desired, and wrap those functions in
`tf.contrib.eager.py_func`.
- For more information on eager execution, see @{$guide/eager}.
+ For more information on eager execution, see the
+ [Eager guide](https://tensorflow.org/guide/eager).
`tf.contrib.eager.py_func` is similar in spirit to `tf.py_func`, but unlike
the latter, the former lets you use TensorFlow operations in the wrapped
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index dee84bab0c..e229501c10 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -13,7 +13,11 @@
# limitations under the License.
# ==============================================================================
-"""Tensor Handle Operations. See the @{$python/session_ops} guide."""
+"""Tensor Handle Operations.
+
+See the [Session Ops](https://tensorflow.org/api_guides/python/session_ops)
+guide.
+"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index fd547dcb19..e91813b4a8 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -14,7 +14,10 @@
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
-"""Sparse Tensor Representation. See the @{$python/sparse_ops} guide."""
+"""Sparse Tensor Representation.
+
+See the [Sparse Ops](https://tensorflow.org/api_guides/python/sparse_ops) guide.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -796,6 +799,11 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
with a single element is returned. Additionally, the axes can be negative,
similar to the indexing rules in Python.
+ The values not defined in `sp_input` don't participate in the reduce max,
+ as opposed to be implicitly assumed 0 -- hence it can return negative values
+ for sparse `reduction_axes`. But, in case there are no values in
+ `reduction_axes`, it will reduce to 0. See second example below.
+
For example:
```python
@@ -807,6 +815,11 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
tf.sparse_reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis.
tf.sparse_reduce_max(x, 1, keepdims=True) ==> [[2], [3]]
tf.sparse_reduce_max(x, [0, 1]) ==> 3
+
+ # 'y' represents [[-7, ?]
+ # [ 4, 3]
+ # [ ?, ?]
+ tf.sparse_reduce_max(x, 1) ==> [-7, 4, 0]
```
Args:
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index d556d11a1b..d1573fbbf2 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================
-"""Variables. See the @{$python/state_ops} guide."""
+"""Variables.
+
+See the [Variables](https://tensorflow.org/api_guides/python/state_ops) guide.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -21,6 +24,8 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
@@ -524,3 +529,101 @@ def scatter_sub(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export("batch_scatter_update")
+def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
+ """Generalization of `tf.scatter_update` to axis different than 0.
+
+ Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
+ have a series of leading dimensions that are the same for all of them, and the
+ updates are performed on the last dimension of indices. In other words, the
+ dimensions should be the following:
+
+ `num_prefix_dims = indices.ndims - 1`
+ `batch_dim = num_prefix_dims + 1`
+ `updates.shape = indices.shape + var.shape[batch_dim:]`
+
+ where
+
+ `updates.shape[:num_prefix_dims]`
+ `== indices.shape[:num_prefix_dims]`
+ `== var.shape[:num_prefix_dims]`
+
+ And the operation performed can be expressed as:
+
+ `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`
+
+ When indices is a 1D tensor, this operation is equivalent to
+ `tf.scatter_update`.
+
+ To avoid this operation there would be 2 alternatives:
+ 1) Reshaping the variable by merging the first `ndims` dimensions. However,
+ this is not possible because `tf.reshape` returns a Tensor, which we
+ cannot use `tf.scatter_update` on.
+ 2) Looping over the first `ndims` of the variable and using
+ `tf.scatter_update` on the subtensors that result of slicing the first
+ dimension. This is a valid option for `ndims = 1`, but less efficient than
+ this implementation.
+
+ See also `tf.scatter_update` and `tf.scatter_nd_update`.
+
+ Args:
+ ref: `Variable` to scatter onto.
+ indices: Tensor containing indices as described above.
+ updates: Tensor of updates to apply to `ref`.
+ use_locking: Boolean indicating whether to lock the writing operation.
+ name: Optional scope name string.
+
+ Returns:
+ Ref to `variable` after it has been modified.
+
+ Raises:
+ ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
+ not the same.
+ """
+ with ops.name_scope(name):
+ indices = ops.convert_to_tensor(indices, name="indices")
+ indices_shape = array_ops.shape(indices)
+ indices_dimensions = indices.get_shape().ndims
+
+ if indices_dimensions is None:
+ raise ValueError("batch_gather does not allow indices with unknown "
+ "shape.")
+
+ nd_indices = array_ops.expand_dims(indices, axis=-1)
+ nd_indices_list = []
+
+ # Scatter ND requires indices to have an additional dimension, in which the
+ # coordinates of the updated things are specified. For this to be adapted to
+ # the scatter_update with several leading dimensions, we simply make use of
+ # a tf.range for all the leading dimensions followed by concat of all the
+ # coordinates we created with the original indices.
+
+ # For example if indices.shape = [2, 3, 4], we should generate the following
+ # indices for tf.scatter_nd_update:
+ # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
+ # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
+ # nd_indices[:, :, 2] = indices
+ for dimension in range(indices_dimensions - 1):
+ # In this loop we generate the following for the example (one for each
+ # iteration).
+ # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
+ # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
+ # This is done at every iteration with a tf.range over the size of the
+ # i-th dimension and using broadcasting over the desired shape.
+ dimension_size = indices_shape[dimension]
+ shape_to_broadcast = [1] * (indices_dimensions + 1)
+ shape_to_broadcast[dimension] = dimension_size
+ dimension_range = array_ops.reshape(
+ gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
+ if dimension_range.dtype.base_dtype != nd_indices.dtype:
+ dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
+ nd_indices_list.append(
+ dimension_range * array_ops.ones_like(nd_indices))
+ # Add the original indices at the end, as described above, and concat.
+ nd_indices_list.append(nd_indices)
+ final_indices = array_ops.concat(nd_indices_list, axis=-1)
+ return scatter_nd_update(
+ ref, final_indices, updates, use_locking=use_locking)
+
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 0280c89c10..c832ba4e2a 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -15,7 +15,7 @@
"""Operations for working with string Tensors.
-See the @{$python/string_ops} guide.
+See the [Strings](https://tensorflow.org/api_guides/python/string_ops) guide.
"""
from __future__ import absolute_import
@@ -24,6 +24,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -31,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -42,6 +44,41 @@ from tensorflow.python.util.tf_export import tf_export
# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)
+
+def regex_replace(source, pattern, rewrite, replace_global=True):
+ r"""Replace elements of `source` matching regex `pattern with `rewrite`.
+
+ Args:
+ source: string `Tensor`, the source strings to process.
+ pattern: string or scalar string `Tensor`, regular expression to use,
+ see more details at https://github.com/google/re2/wiki/Syntax
+ rewrite: string or scalar string `Tensor`, value to use in match
+ replacement, supports backslash-escaped digits (\1 to \9) can be to insert
+ text matching corresponding parenthesized group.
+ replace_global: `bool`, if `True` replace all non-overlapping matches,
+ else replace only the first match.
+
+ Returns:
+ string `Tensor` of the same shape as `source` with specified replacements.
+ """
+ # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+ if not compat.forward_compatible(2018, 10, 10):
+ return gen_string_ops.regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+ if (isinstance(pattern, util_compat.bytes_or_text_types) and
+ isinstance(rewrite, util_compat.bytes_or_text_types)):
+ # When `pattern` and `rewrite` are static through the life of the op we can
+ # use a version which performs the expensive regex compilation once at
+ # creation time.
+ return gen_string_ops.static_regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+ return gen_string_ops.regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index c248dd9172..46bcd68f1a 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -837,9 +838,6 @@ class _VariableStore(object):
raise ValueError("Variable %s does not exist, or was not created with "
"tf.get_variable(). Did you mean to set "
"reuse=tf.AUTO_REUSE in VarScope?" % name)
- if not shape.is_fully_defined() and not initializing_from_value:
- raise ValueError("Shape of a new variable (%s) must be fully defined, "
- "but instead was %s." % (name, shape))
# Create the tensor to initialize the variable with default value.
if initializer is None:
@@ -854,8 +852,17 @@ class _VariableStore(object):
# Instantiate initializer if provided initializer is a type object.
if isinstance(initializer, type(init_ops.Initializer)):
initializer = initializer(dtype=dtype)
- init_val = lambda: initializer( # pylint: disable=g-long-lambda
- shape.as_list(), dtype=dtype, partition_info=partition_info)
+ if shape and shape.is_fully_defined():
+ init_val = lambda: initializer( # pylint: disable=g-long-lambda
+ shape.as_list(), dtype=dtype, partition_info=partition_info)
+ elif not tf_inspect.getargspec(initializer).args:
+ init_val = initializer
+ else:
+ raise ValueError("You can only pass an initializer function that"
+ "expects no arguments to its callable when the "
+ "shape is not fully defined. The given initializer "
+ "function expects the following args %s" %
+ tf_inspect.getargspec(initializer).args)
variable_dtype = dtype.base_dtype
# Create the variable.
@@ -1440,12 +1447,11 @@ def get_variable(name,
aggregation=aggregation)
-get_variable_or_local_docstring = (
- """%s
+get_variable_or_local_docstring = ("""%s
%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
-@{$variables$Variable Scope How To}
+[Variable Scope How To](https://tensorflow.org/guide/variables)
for an extensive description of how reusing works. Here is a basic example:
```python
@@ -1895,8 +1901,8 @@ class variable_scope(object):
Variable scope allows you to create new variables and to share already created
ones while providing checks to not create or share by accident. For details,
- see the @{$variables$Variable Scope How To}, here we present only a few basic
- examples.
+ see the [Variable Scope How To](https://tensorflow.org/guide/variables), here
+ we present only a few basic examples.
Simple example of how to create a new variable:
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 402ab2dd9d..7a28615ba9 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -135,7 +135,7 @@ class VariableMetaclass(type):
@tf_export("Variable")
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
- """See the @{$variables$Variables How To} for a high level overview.
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
A variable maintains state in the graph across calls to `run()`. You add a
variable to the graph by constructing an instance of the class `Variable`.
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 9ffb48c4a5..5dc4037d62 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -15,7 +15,7 @@
"""Testing.
-See the @{$python/test} guide.
+See the [Testing](https://tensorflow.org/api_guides/python/test) guide.
Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock`
depending on the python version.
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 980320cc66..fbae2b77fa 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -15,7 +15,7 @@
"""Tensor summaries for exporting information about a model.
-See the @{$python/summary} guide.
+See the [Summary](https://tensorflow.org/api_guides/python/summary) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 222f856511..01d43e09d1 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -114,6 +114,12 @@ py_library(
],
)
+py_library(
+ name = "component_api_helper",
+ srcs = ["component_api_helper.py"],
+ srcs_version = "PY2AND3",
+)
+
py_binary(
name = "strip_unused",
srcs = ["strip_unused.py"],
diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py
new file mode 100644
index 0000000000..988ecc61f0
--- /dev/null
+++ b/tensorflow/python/tools/component_api_helper.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions to help integrate TensorFlow components into TF API.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import os
+
+
+def package_hook(parent_package_str, child_package_str, error_msg=None):
+ """Used to hook in an external package into the TensorFlow namespace.
+
+ Example usage:
+ ### tensorflow/__init__.py
+ from tensorflow.python.tools import component_api_helper
+ component_api_helper.package_hook(
+ 'tensorflow', 'tensorflow_estimator.python')
+ component_api_helper(
+ 'tensorflow.contrib', 'tensorflow_estimator.contrib.python')
+ del component_api_helper
+
+ TODO(mikecase): This function has a minor issue, where if the child package
+ does not exist alone in its directory, sibling packages to it will also be
+ accessible from the parent. This is because we just add
+ `child_pkg.__file__/..` to the subpackage search path. This should not be
+ a big issue because of how our API generation scripts work (the child package
+ we are hooking up should always be alone). But there might be a better way
+ of doing this.
+
+ Args:
+ parent_package_str: Parent package name as a string such as 'tensorflow' or
+ 'tensorflow.contrib'. This will become the parent package for the
+ component package being hooked in.
+ child_package_str: Child package name as a string such as
+ 'tensorflow_estimator.python'. This package will be added as a subpackage
+ of the parent.
+ error_msg: Message to print if child package cannot be found.
+ """
+ parent_pkg = importlib.import_module(parent_package_str)
+ try:
+ child_pkg = importlib.import_module(child_package_str)
+ except ImportError:
+ if error_msg:
+ print(error_msg)
+ return
+
+ def set_child_as_subpackage():
+ """Sets child package as a subpackage of parent package.
+
+ Will allow the following import statement to work.
+ >>> import parent.child
+ """
+ child_pkg_path = [os.path.join(os.path.dirname(child_pkg.__file__), "..")]
+ try:
+ parent_pkg.__path__ += child_pkg_path
+ except AttributeError:
+ parent_pkg.__path__ = child_pkg_path
+
+ def set_child_as_attr():
+ """Sets child package as a attr of the parent package.
+
+ Will allow for the following.
+ >>> import parent
+ >>> parent.child
+ """
+ child_pkg_attr_name = child_pkg.__name__.split(".")[-1]
+ setattr(parent_pkg, child_pkg_attr_name, child_pkg)
+
+ set_child_as_subpackage()
+ set_child_as_attr()
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index acf070075e..c7f414c5dc 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -59,7 +59,7 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
-def _has_variables(sess):
+def _has_no_variables(sess):
"""Determines if the graph has any variables.
Args:
@@ -168,7 +168,7 @@ def freeze_graph_with_def_protos(input_graph_def,
"the flag --input_saved_model_dir.")
return -1
# Models that have been frozen previously do not contain Variables.
- elif _has_variables(sess):
+ elif _has_no_variables(sess):
print("No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
return 0
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 6778f3c735..3508b98475 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -70,20 +70,24 @@ class AdagradOptimizer(optimizer.Optimizer):
def _create_slots(self, var_list):
for v in var_list:
- with ops.colocate_with(v):
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- # Use a Tensor instead of initializer if variable does not have static
- # shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- init = math_ops.cast(init_constant, dtype)
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ init = self._init_constant_op(v, dtype)
self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype,
"accumulator", self._name)
+ def _init_constant_op(self, v, dtype):
+ def init():
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ return init
+
def _prepare(self):
learning_rate = self._call_if_callable(self._learning_rate)
self._learning_rate_tensor = ops.convert_to_tensor(
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index c9aec33d09..4e634fff84 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -302,6 +302,39 @@ class AdagradOptimizerTest(test.TestCase):
# Creating optimizer should cause no exception.
adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+ def testDynamicShapeVariableWithCallableInit(self):
+ var0 = variable_scope.get_variable("var0",
+ initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(var0.shape.is_fully_defined())
+
+ grads0 = constant_op.constant(0.1, dtype=dtypes.float32)
+ learning_rate = lambda: 3.0
+
+ ada_opt = adagrad.AdagradOptimizer(
+ learning_rate, initial_accumulator_value=0.1, use_locking=True)
+
+ if not context.executing_eagerly():
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0], [var0]))
+ self.evaluate(variables.global_variables_initializer())
+
+ # Fetch params to validate initial values
+ v0_val = self.evaluate([var0])
+ self.assertAllClose([1.0], v0_val)
+
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ if not context.executing_eagerly():
+ self.evaluate(ada_update)
+ else:
+ ada_opt.apply_gradients(zip([grads0], [var0]))
+
+ # Validate updated params
+ v0_val = self.evaluate([var0])
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932]), v0_val)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 6e9b8ff905..d26932c1aa 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -101,15 +101,26 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base",
+ ":data_structures",
":tracking",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:io_ops_gen",
- "//tensorflow/python:ops",
+ "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:saveable_object",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 390434c0a2..9189d8f3e8 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -22,6 +22,7 @@ import functools
import json
import weakref
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -93,14 +94,17 @@ class CheckpointInitialValue(ops.Tensor):
class PythonStringStateSaveable(saveable_object.SaveableObject):
"""Saves Python state in a checkpoint."""
- def __init__(self, name, state_callback):
+ def __init__(self, name, state_callback, restore_callback=None):
"""Configure saving.
Args:
name: The checkpoint key to write to.
state_callback: A function taking no arguments which returns a
string. This function is run every time a checkpoint is written.
+ restore_callback: A function taking a Python string, used to restore
+ state. Optional; defaults to doing nothing.
"""
+ self._restore_callback = restore_callback
if context.executing_eagerly():
self._save_string = (
lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
@@ -113,9 +117,14 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def python_restore(self, restored_strings):
+ """Called to restore Python state."""
+ if self._restore_callback:
+ restored, = restored_strings
+ self._restore_callback(restored)
+
def restore(self, restored_tensors, restored_shapes):
- # TODO(allenl): Add a Python hook for state coming out of a checkpoint
- # (currently PythonStringStateSaveable is write-only).
+ """Called to restore TensorFlow state (nothing to do)."""
return control_flow_ops.no_op()
@@ -227,7 +236,7 @@ class _CheckpointPosition(object):
with ops.device("/cpu:0"):
# Run the restore itself on the CPU.
value, = io_ops.restore_v2(
- prefix=self._checkpoint.save_path,
+ prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
shape_and_slices=[""],
dtypes=[base_type],
@@ -236,42 +245,99 @@ class _CheckpointPosition(object):
value_tensors[serialized_tensor.name] = array_ops.identity(value)
return value_tensors
- def restore_ops(self):
- """Create or fetch restore ops for this object's attributes.
-
- Requires that the `Checkpointable` Python object has been bound to an object
- ID in the checkpoint.
-
- Returns:
- A list of operations when graph building, or an empty list when executing
- eagerly.
- """
+ def _gather_ops_or_named_saveables(self):
+ """Looks up or creates SaveableObjects which don't have cached ops."""
saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
- restore_ops = []
- building_graph = not context.executing_eagerly()
+ python_saveables = []
+ existing_restore_ops = []
for serialized_tensor in self.object_proto.attributes:
- saveable_factory = saveables.get(serialized_tensor.name, None)
- if saveable_factory is None:
- # Purposefully does not throw an exception if attributes have been added
- # or deleted. Stores unused attributes so an exception can be raised if
- # the user decides to check that everything in the checkpoint was
- # loaded.
- self._checkpoint.unused_attributes.setdefault(
- self.checkpointable, []).append(serialized_tensor.name)
+ if context.executing_eagerly():
+ existing_op = None
+ else:
+ existing_op = self._checkpoint.restore_ops_by_name.get(
+ serialized_tensor.checkpoint_key, None)
+ if existing_op is not None:
+ existing_restore_ops.append(existing_op)
continue
- if building_graph:
- existing_ops = self._checkpoint.restore_ops_by_name.get(
- serialized_tensor.name, None)
+
+ # Only if we don't have cached ops for this SaveableObject, we'll see if
+ # the SaveableObject itself has been cached. If not, we'll make it, and
+ # either way we'll extract new ops from it (or if it has Python state to
+ # restore, we'll run that).
+ if self._checkpoint.saveable_object_cache is None:
+ # No SaveableObject caching when executing eagerly.
+ saveable = None
else:
- existing_ops = None
- if existing_ops is None:
+ # If we've already created and cached a SaveableObject for this
+ # attribute, we can re-use it to avoid re-creating some ops when graph
+ # building.
+ saveable_list = self._checkpoint.saveable_object_cache.get(
+ self.checkpointable, {}).get(serialized_tensor.name, (None,))
+ if len(saveable_list) == 1:
+ # Almost every attribute will have exactly one SaveableObject.
+ saveable, = saveable_list
+ else:
+ # Don't use cached SaveableObjects for partitioned variables, which is
+ # the only case where we'd have a list of SaveableObjects. Op caching
+ # will catch them.
+ saveable = None
+ if saveable is not None:
+ # The name of this attribute has changed, so we need to re-generate
+ # the SaveableObject.
+ if serialized_tensor.checkpoint_key not in saveable.name:
+ saveable = None
+ del self._checkpoint.saveable_object_cache[self.checkpointable]
+ break
+ if saveable is None:
+ # If there was no cached SaveableObject, we should check if the Python
+ # object has the attribute.
+ saveable_factory = saveables.get(serialized_tensor.name, None)
+ if saveable_factory is None:
+ # Purposefully does not throw an exception if attributes have been
+ # added or deleted. Stores unused attributes so an exception can be
+ # raised if the user decides to check that everything in the
+ # checkpoint was loaded.
+ self._checkpoint.unused_attributes.setdefault(
+ self.checkpointable, []).append(serialized_tensor.name)
+ continue
if callable(saveable_factory):
saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
else:
saveable = saveable_factory
+ if self._checkpoint.saveable_object_cache is not None:
+ self._checkpoint.saveable_object_cache.setdefault(
+ self.checkpointable, {})[serialized_tensor.name] = [saveable]
+ if isinstance(saveable, PythonStringStateSaveable):
+ python_saveables.append(saveable)
+ else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
+ return existing_restore_ops, named_saveables, python_saveables
+
+ def restore_ops(self):
+ """Create or fetch restore ops for this object's attributes.
+
+ Requires that the `Checkpointable` Python object has been bound to an object
+ ID in the checkpoint.
+
+ Returns:
+ A list of operations when graph building, or an empty list when executing
+ eagerly.
+ """
+ (restore_ops,
+ named_saveables,
+ python_saveables) = self._gather_ops_or_named_saveables()
+
+ # Eagerly run restorations for Python state.
+ reader = pywrap_tensorflow.NewCheckpointReader(
+ self._checkpoint.save_path_string)
+ for saveable in python_saveables:
+ spec_names = [spec.name for spec in saveable.specs]
+ saveable.python_restore(
+ [reader.get_tensor(name) for name in spec_names])
+
+ # If we have new SaveableObjects, extract and cache restore ops.
if named_saveables:
validated_saveables = (
self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access
@@ -281,7 +347,7 @@ class _CheckpointPosition(object):
("Saveable keys changed when validating. Got back %s, was "
"expecting %s") % (named_saveables.keys(), validated_names))
all_tensors = self._checkpoint.builder.bulk_restore(
- filename_tensor=self._checkpoint.save_path,
+ filename_tensor=self._checkpoint.save_path_tensor,
saveables=validated_saveables, preferred_shard=-1,
restore_sequentially=False)
saveable_index = 0
@@ -291,7 +357,7 @@ class _CheckpointPosition(object):
saveable_index:saveable_index + num_specs]
saveable_index += num_specs
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
- if building_graph:
+ if not context.executing_eagerly():
assert saveable.name not in self._checkpoint.restore_ops_by_name
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
restore_ops.append(restore_op)
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 507cda8734..f06cbbfa15 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -128,7 +128,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
if (isinstance(value, CheckpointableDataStructure)
- or layer_utils.is_layer(value)):
+ or layer_utils.is_layer(value)
+ or layer_utils.has_weights(value)):
# Check for object-identity rather than with __eq__ to avoid
# de-duplicating empty container types. Automatically generated list
# wrappers keep things like "[] == []" true, which means "[] in [[]]" is
@@ -149,14 +150,14 @@ class CheckpointableDataStructure(base.CheckpointableBase):
def trainable_weights(self):
return layer_utils.gather_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
return layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
@@ -183,7 +184,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
# have any inputs.
aggregated = []
for layer in self.layers:
- aggregated += layer.updates
+ if hasattr(layer, "updates"):
+ aggregated += layer.updates
return aggregated
@property
@@ -191,7 +193,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
"""Aggregate losses from any `Layer` instances."""
aggregated = []
for layer in self.layers:
- aggregated += layer.losses
+ if hasattr(layer, "losses"):
+ aggregated += layer.losses
return aggregated
def __hash__(self):
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 472b7c32b4..4638917b4c 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -96,6 +97,11 @@ class ListTests(test.TestCase):
model.load_weights(save_path)
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
self.evaluate(model.variables[0]))
+ v = variables.Variable(1.)
+ model.var_list = [v]
+ self.assertIn(v, model.variables)
+ self.assertIn(v, model.trainable_variables)
+ self.assertNotIn(v, model.non_trainable_variables)
def testUpdatesForwarded(self):
with context.graph_mode():
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index d65b631fe9..ec764bca89 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -30,13 +30,20 @@ def is_layer(obj):
and hasattr(obj, "variables"))
+def has_weights(obj):
+ """Implicit check for Layer-like objects."""
+ # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
+ return (hasattr(obj, "trainable_weights")
+ and hasattr(obj, "non_trainable_weights"))
+
+
def filter_empty_layer_containers(layer_list):
"""Filter out empty Layer-like containers."""
filtered = []
for obj in layer_list:
if is_layer(obj):
filtered.append(obj)
- else:
+ elif hasattr(obj, "layers"):
# Checkpointable data structures will not show up in ".layers" lists, but
# the layers they contain will.
filtered.extend(obj.layers)
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index e42f989469..f49ed5c9ff 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -68,16 +68,25 @@ _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
class _CheckpointRestoreCoordinator(object):
"""Holds the status of an object-based checkpoint load."""
- def __init__(self, object_graph_proto, save_path, dtype_map=None):
+ def __init__(self, object_graph_proto, save_path, save_path_tensor,
+ restore_op_cache, saveable_object_cache):
"""Specify the checkpoint being loaded.
Args:
object_graph_proto: The CheckpointableObjectGraph protocol buffer
associated with this checkpoint.
- save_path: A string `Tensor`. The path to the checkpoint, as returned by
+ save_path: A string, the path to the checkpoint, as returned by
`tf.train.latest_checkpoint`.
- dtype_map: When executing eagerly, specifies dtypes for creating slot
- variables. None when graph building.
+ save_path_tensor: A string `Tensor` which contains or will be fed the save
+ path.
+ restore_op_cache: A dictionary shared between
+ `_CheckpointRestoreCoordinator`s for the same Python objects, used to
+ look up restore ops by name to avoid re-creating them across multiple
+ `restore()` calls.
+ saveable_object_cache: A mapping of checkpointable objects -> attribute
+ names -> list(`SaveableObject`s), used when `SaveableObjects` must be
+ referenced every restore (e.g. for Python state); otherwise they would
+ create their own ops every restore.
"""
self.builder = saver_lib.BulkSaverBuilder()
self.object_graph_proto = object_graph_proto
@@ -97,12 +106,17 @@ class _CheckpointRestoreCoordinator(object):
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
self.all_python_objects = _ObjectIdentityWeakSet()
- self.save_path = save_path
- self.dtype_map = dtype_map
+ self.save_path_tensor = save_path_tensor
+ self.save_path_string = save_path
+ self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
+ save_path).get_variable_to_dtype_map()
+ # A NewCheckpointReader for the most recent checkpoint, for streaming Python
+ # state restoration.
# When graph building, contains a list of ops to run to restore objects from
# this checkpoint.
self.restore_ops = []
- self.restore_ops_by_name = {}
+ self.restore_ops_by_name = restore_op_cache
+ self.saveable_object_cache = saveable_object_cache
self.new_restore_ops_callback = None
# A mapping from optimizer proto ids to lists of slot variables to be
# restored when the optimizer is tracked. Only includes slot variables whose
@@ -1153,16 +1167,15 @@ class CheckpointableSaver(object):
self._last_save_object_graph = None
self._last_save_saver = None
- # Op caching for restore
- self._last_restore_object_graph = None
- self._last_restore_checkpoint = None
+ # Op caching for restore, shared between _CheckpointRestoreCoordinators
+ self._restore_op_cache = {}
if context.executing_eagerly():
# SaveableObjects are always recreated when executing eagerly.
self._saveable_object_cache = None
else:
- # Maps Checkpointable objects -> attribute names -> SaveableObjects, to
- # avoid re-creating SaveableObjects when graph building.
+ # Maps Checkpointable objects -> attribute names -> list(SaveableObjects),
+ # to avoid re-creating SaveableObjects when graph building.
self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
@property
@@ -1340,22 +1353,12 @@ class CheckpointableSaver(object):
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
- if graph_building and object_graph_proto == self._last_restore_object_graph:
- checkpoint = self._last_restore_checkpoint
- else:
- checkpoint = _CheckpointRestoreCoordinator(
- object_graph_proto=object_graph_proto,
- save_path=file_prefix_tensor,
- dtype_map=dtype_map)
- if graph_building:
- if self._last_restore_object_graph is not None:
- raise NotImplementedError(
- "Using a single Saver to restore different object graphs is not "
- "currently supported when graph building. Use a different Saver "
- "for each object graph (restore ops will be duplicated), or "
- "file a feature request if this limitation bothers you.")
- self._last_restore_checkpoint = checkpoint
- self._last_restore_object_graph = object_graph_proto
+ checkpoint = _CheckpointRestoreCoordinator(
+ object_graph_proto=object_graph_proto,
+ save_path=save_path,
+ save_path_tensor=file_prefix_tensor,
+ restore_op_cache=self._restore_op_cache,
+ saveable_object_cache=self._saveable_object_cache)
base._CheckpointPosition( # pylint: disable=protected-access
checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
load_status = CheckpointLoadStatus(
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index a0a87b6b79..cac293e916 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -1073,16 +1073,11 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(5, self.evaluate(checkpoint.var_5))
self.assertEqual(1, self.evaluate(checkpoint.var_1))
self.assertEqual(0, self.evaluate(checkpoint.var_0))
- if context.executing_eagerly():
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
- self.assertEqual(9, self.evaluate(checkpoint.var_9))
- self.assertEqual(8, self.evaluate(checkpoint.var_8))
- self.assertEqual(1, self.evaluate(checkpoint.var_1))
- self.assertEqual(0, self.evaluate(checkpoint.var_0))
- else:
- # Restoring into modified graphs is an error while graph building.
- with self.assertRaises(NotImplementedError):
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ self.assertEqual(9, self.evaluate(checkpoint.var_9))
+ self.assertEqual(8, self.evaluate(checkpoint.var_8))
+ self.assertEqual(1, self.evaluate(checkpoint.var_1))
+ self.assertEqual(0, self.evaluate(checkpoint.var_0))
def testManyRestoresGraph(self):
"""Restores after the first should not modify the graph."""
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 28c60ad809..1ac7c39872 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -248,6 +248,7 @@ class DistributionStrategy(object):
devices.
We have then a few approaches we want to support:
+
* Code written (as if) with no knowledge of class `DistributionStrategy`.
This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done
@@ -624,13 +625,18 @@ class DistributionStrategy(object):
Args:
fn: function to run using this distribution strategy. The function must
- have the following signature: def fn(context, inputs).
+ have the following signature: def fn(context, *inputs).
`context` is an instance of `MultiStepContext` that will be passed when
`fn` is run. `context` can be used to specify the outputs to be returned
from `fn` by calling `context.set_last_step_output`. It can also be used
to capture non tensor outputs by `context.set_non_tensor_output`.
See `MultiStepContext` documentation for more information.
- `inputs` will have same type/structure as `iterator.get_next()`.
+ `inputs` will have same type/structure as `iterator.get_next()`. If the
+ `iterator.get_next()` returns a tuple say `return x, y` then whose will
+ be unpacked and passed to the `step_fn`; and step_fn signature would
+ look like `def step_fn(context, x, y)`. If the iterator returns a single
+ value say `return x` then the value is passed as is; the step_fn
+ signature would look like `def step_fn(context, x)`.
Typically, `fn` will use `call_for_each_tower` method of the strategy
to distribute the computation over multiple towers.
iterator: Iterator of a dataset that represents the input for `fn`. The
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index caa26581e8..0d6207f8c4 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -15,7 +15,8 @@
"""Input pipeline.
-Please see the @{$reading_data$reading data how-to}
+Please see the [reading data
+how-to](https://tensorflow.org/api_guides/python/reading_data)
for context.
"""
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7b06bffa4b..c077630de2 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -25,6 +25,7 @@ import sys
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -284,6 +285,63 @@ class Scaffold(object):
resources.initialize_resources(resources.local_resources()))
+def _create_monitored_session_with_worker_context(worker_context, # pylint: disable=missing-docstring
+ scaffold,
+ checkpoint_dir=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=None,
+ save_summaries_steps=None,
+ save_summaries_secs=None,
+ config=None,
+ stop_grace_period_secs=120,
+ log_step_count_steps=100,
+ max_wait_secs=7200,
+ save_checkpoint_steps=None,
+ summary_dir=None):
+ all_hooks = []
+ if hooks:
+ all_hooks.extend(hooks)
+ if chief_only_hooks and worker_context.is_chief:
+ all_hooks.extend(chief_only_hooks)
+
+ summary_dir = summary_dir or checkpoint_dir
+ if summary_dir and worker_context.should_save_summary:
+ if log_step_count_steps and log_step_count_steps > 0:
+ all_hooks.append(
+ basic_session_run_hooks.StepCounterHook(
+ output_dir=summary_dir, every_n_steps=log_step_count_steps))
+
+ if (save_summaries_steps and save_summaries_steps > 0) or (
+ save_summaries_secs and save_summaries_secs > 0):
+ all_hooks.append(
+ basic_session_run_hooks.SummarySaverHook(
+ scaffold=scaffold,
+ save_steps=save_summaries_steps,
+ save_secs=save_summaries_secs,
+ output_dir=summary_dir))
+
+ if checkpoint_dir and worker_context.should_checkpoint:
+ if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+ save_checkpoint_steps and save_checkpoint_steps > 0):
+ all_hooks.append(
+ basic_session_run_hooks.CheckpointSaverHook(
+ checkpoint_dir,
+ save_steps=save_checkpoint_steps,
+ save_secs=save_checkpoint_secs,
+ scaffold=scaffold))
+
+ session_creator = worker_context.session_creator(
+ scaffold,
+ config=config,
+ checkpoint_dir=checkpoint_dir,
+ max_wait_secs=max_wait_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=all_hooks,
+ stop_grace_period_secs=stop_grace_period_secs)
+
+
@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
is_chief=True,
@@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_checkpoint_steps = None
scaffold = scaffold or Scaffold()
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+
+ if worker_context:
+ return _create_monitored_session_with_worker_context(
+ worker_context,
+ scaffold,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ chief_only_hooks=chief_only_hooks,
+ save_checkpoint_secs=save_checkpoint_secs,
+ save_summaries_steps=save_summaries_steps,
+ save_summaries_secs=save_summaries_secs,
+ config=config,
+ stop_grace_period_secs=stop_grace_period_secs,
+ log_step_count_steps=log_step_count_steps,
+ max_wait_secs=max_wait_secs,
+ save_checkpoint_steps=save_checkpoint_steps,
+ summary_dir=summary_dir)
+
if not is_chief:
session_creator = WorkerSessionCreator(
scaffold=scaffold,
master=master,
config=config,
max_wait_secs=max_wait_secs)
- return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
- stop_grace_period_secs=stop_grace_period_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=hooks or [],
+ stop_grace_period_secs=stop_grace_period_secs)
all_hooks = []
if chief_only_hooks:
@@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
- all_hooks.append(basic_session_run_hooks.SummarySaverHook(
- scaffold=scaffold,
- save_steps=save_summaries_steps,
- save_secs=save_summaries_secs,
- output_dir=summary_dir))
+ all_hooks.append(
+ basic_session_run_hooks.SummarySaverHook(
+ scaffold=scaffold,
+ save_steps=save_summaries_steps,
+ save_secs=save_summaries_secs,
+ output_dir=summary_dir))
if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
- all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
- checkpoint_dir,
- save_steps=save_checkpoint_steps,
- save_secs=save_checkpoint_secs,
- scaffold=scaffold))
+ all_hooks.append(
+ basic_session_run_hooks.CheckpointSaverHook(
+ checkpoint_dir,
+ save_steps=save_checkpoint_steps,
+ save_secs=save_checkpoint_secs,
+ scaffold=scaffold))
if hooks:
all_hooks.extend(hooks)
- return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
- stop_grace_period_secs=stop_grace_period_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=all_hooks,
+ stop_grace_period_secs=stop_grace_period_secs)
@tf_export('train.SessionCreator')
@@ -546,6 +629,11 @@ class _MonitoredSession(object):
self._hooks = hooks or []
for h in self._hooks:
h.begin()
+
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+ if not session_creator and worker_context:
+ session_creator = worker_context.session_creator()
+
# Create the session.
self._coordinated_creator = self._CoordinatedSessionCreator(
session_creator=session_creator or ChiefSessionCreator(),
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 92533ca4f3..ff586b6c03 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
self.assertEqual(0, session.run(gstep))
+class MockStrategy(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=True,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self._between_graph = between_graph
+ self._should_init = should_init
+ self._should_checkpoint = should_checkpoint
+ self._should_save_summary = should_save_summary
+
+ @property
+ def between_graph(self):
+ return self._between_graph
+
+ @property
+ def should_init(self):
+ return self._should_init
+
+ @property
+ def should_checkpoint(self):
+ return self._should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ return self._should_save_summary
+
+
+class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
+ """Test distribute coordinator controls summary saving and checkpointing."""
+
+ def test_summary_hook_enabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_save_summary=True), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ summary.scalar('my_summary_tag', new_gstep * 2)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_summaries_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(101):
+ session.run(new_gstep)
+
+ summaries = util_test.latest_summaries(logdir)
+ tags = [s.summary.value[0].tag for s in summaries]
+ self.assertIn('my_summary_tag', tags)
+ self.assertIn('global_step/sec', tags)
+
+ def test_summary_hook_disabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_save_summary=False), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ summary.scalar('my_summary_tag', new_gstep * 2)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_summaries_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(101):
+ session.run(new_gstep)
+
+ # No summary is saved.
+ summaries = util_test.latest_summaries(logdir)
+ self.assertEqual(len(summaries), 0)
+
+ def test_checkpoint_hook_enabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_checkpoint=True), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(100, session.run(gstep))
+
+ def test_checkpoint_hook_disabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_checkpoint=False), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+
+ # No checkpoint is saved.
+ checkpoint = checkpoint_management.latest_checkpoint(logdir)
+ self.assertIsNone(checkpoint)
+
+
class StopAtNSession(monitored_session._WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""
@@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
- checkpoint_filename_with_path=
- checkpoint_management.latest_checkpoint(logdir))) as session:
+ checkpoint_filename_with_path=checkpoint_management.
+ latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 04fce496bd..274c856686 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -809,6 +809,22 @@ class BaseSaverBuilder(object):
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
else:
+ graph = ops.get_default_graph()
+ # Do some sanity checking on collections containing
+ # PartitionedVariables. If a saved collection has a PartitionedVariable,
+ # the GraphDef needs to include concat ops to get the value (or there'll
+ # be a lookup error on load).
+ check_collection_list = graph.get_all_collection_keys()
+ for collection_type in check_collection_list:
+ for element in graph.get_collection(collection_type):
+ if isinstance(element, variables.PartitionedVariable):
+ try:
+ graph.get_operation_by_name(element.name)
+ except KeyError:
+ # Create a concat op for this PartitionedVariable. The user may
+ # not need it, but we'll try looking it up on MetaGraph restore
+ # since it's in a collection.
+ element.as_tensor()
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
@@ -869,7 +885,7 @@ def _get_saver_or_default():
class Saver(object):
"""Saves and restores variables.
- See @{$variables$Variables}
+ See [Variables](https://tensorflow.org/guide/variables)
for an overview of variables, saving and restoring.
The `Saver` class adds ops to save and restore variables to and from
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 0c6cf910d1..7afaa92699 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -53,7 +53,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
which replicas can fetch the new variables and continue.
The following accumulators/queue are created:
- <empty line>
+
* N `gradient accumulators`, one per variable to train. Gradients are pushed
to them and the chief worker will wait until enough gradients are collected
and then average them before applying to variables. The accumulator will
@@ -68,7 +68,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
The optimizer adds nodes to the graph to collect gradients and pause the
trainers until variables are updated.
For the Parameter Server job:
- <empty line>
+
1. An accumulator is created for each variable, and each replica pushes the
gradients into the accumulators instead of directly applying them to the
variables.
@@ -81,7 +81,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
update its local_step variable and start the next batch.
For the replicas:
- <empty line>
+
1. Start a step: fetch variables and compute gradients.
2. Once the gradients have been computed, push them into gradient
accumulators. Each accumulator will check the staleness and drop the stale.
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 6f6305a505..686c4be31a 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -15,7 +15,7 @@
"""Support for training models.
-See the @{$python/train} guide.
+See the [Training](https://tensorflow.org/api_guides/python/train) guide.
"""
# Optimizers.
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index faae0d89c3..2968ca9c07 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False):
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
+# See the swig file (util.i) for documentation.
+_is_mapping = _pywrap_tensorflow.IsMapping
+
+
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@@ -73,7 +77,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, (dict, _collections.Mapping)):
+ if _is_mapping(instance):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +93,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, (dict, _collections.Mapping)):
+ if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -102,53 +106,16 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if its input is a collections.Sequence (except strings).
+# See the swig file (util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequence
- Args:
- seq: an input sequence.
- Returns:
- True if the sequence is a not a string and is a collections.Sequence or a
- dict.
- """
- return _pywrap_tensorflow.IsSequence(seq)
+# See the swig file (util.i) for documentation.
+flatten = _pywrap_tensorflow.Flatten
-def flatten(nest):
- """Returns a flat list from a given nested structure.
-
- If `nest` is not a sequence, tuple, or dict, then returns a single-element
- list: `[nest]`.
-
- In the case of dict instances, the sequence consists of the values, sorted by
- key to ensure deterministic behavior. This is true also for `OrderedDict`
- instances: their sequence order is ignored, the sorting order of keys is
- used instead. The same convention is followed in `pack_sequence_as`. This
- correctly repacks dicts and `OrderedDict`s after they have been flattened,
- and also allows flattening an `OrderedDict` and then repacking it back using
- a corresponding plain dict, or vice-versa.
- Dictionaries with non-sortable keys cannot be flattened.
-
- Users must not modify any collections used in `nest` while this function is
- running.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object. Note, numpy
- arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
-
- Raises:
- TypeError: The nest is or contains a dict with non-sortable keys.
- """
- return _pywrap_tensorflow.Flatten(nest)
-
-
-def _same_namedtuples(nest1, nest2):
- """Returns True if the two namedtuples have the same name and fields."""
- return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
+# See the swig file (util.i) for documentation.
+_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
def assert_same_structure(nest1, nest2, check_types=True):
@@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence):
% len(flat_sequence))
return flat_sequence[0]
- flat_structure = flatten(structure)
- if len(flat_structure) != len(flat_sequence):
- raise ValueError(
- "Could not pack sequence. Structure had %d elements, but flat_sequence "
- "had %d elements. Structure: %s, flat_sequence: %s."
- % (len(flat_structure), len(flat_sequence), structure, flat_sequence))
-
- _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ try:
+ final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ if final_index < len(flat_sequence):
+ raise IndexError
+ except IndexError:
+ flat_structure = flatten(structure)
+ if len(flat_structure) != len(flat_sequence):
+ raise ValueError(
+ "Could not pack sequence. Structure had %d elements, but "
+ "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." %
+ (len(flat_structure), len(flat_sequence), structure, flat_sequence))
return _sequence_like(structure, packed)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ebb72079ef..61249d664b 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -647,6 +647,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
}
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
+bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 41dcc969f8..f15ebb6efe 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -47,6 +47,15 @@ bool IsSequence(PyObject* o);
// True if `instance` is a `namedtuple`.
PyObject* IsNamedtuple(PyObject* o, bool strict);
+// Returns a true if its input is a collections.Mapping.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the sequence subclasses mapping.
+bool IsMapping(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6ad1484295..8d9f9615d7 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -37,18 +37,70 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%feature("docstring") tensorflow::swig::IsSequence
+"""Returns a true if its input is a collections.Sequence (except strings).
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the sequence is a not a string and is a collections.Sequence or a
+ dict.
+"""
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
%unignore tensorflow::swig::IsNamedtuple;
%noexception tensorflow::swig::IsNamedtuple;
+%feature("docstring") tensorflow::swig::IsMapping
+"""Returns True iff `instance` is a `collections.Mapping`.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is a `collections.Mapping`.
+"""
+%unignore tensorflow::swig::IsMapping;
+%noexception tensorflow::swig::IsMapping;
+
+%feature("docstring") tensorflow::swig::SameNamedtuples
+"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
%noexception tensorflow::swig::SameNamedtuples;
%unignore tensorflow::swig::AssertSameStructure;
%noexception tensorflow::swig::AssertSameStructure;
+%feature("docstring") tensorflow::swig::Flatten
+"""Returns a flat list from a given nested structure.
+
+If `nest` is not a sequence, tuple, or dict, then returns a single-element
+list: `[nest]`.
+
+In the case of dict instances, the sequence consists of the values, sorted by
+key to ensure deterministic behavior. This is true also for `OrderedDict`
+instances: their sequence order is ignored, the sorting order of keys is
+used instead. The same convention is followed in `pack_sequence_as`. This
+correctly repacks dicts and `OrderedDict`s after they have been flattened,
+and also allows flattening an `OrderedDict` and then repacking it back using
+a corresponding plain dict, or vice-versa.
+Dictionaries with non-sortable keys cannot be flattened.
+
+Users must not modify any collections used in `nest` while this function is
+running.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object. Note, numpy
+ arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+
+Raises:
+ TypeError: The nest is or contains a dict with non-sortable keys.
+"""
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 2f19147dbb..6d6e8941c5 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -31,12 +31,10 @@ load(
"//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
)
-
load(
"//third_party/ngraph:build_defs.bzl",
"if_ngraph",
)
-
def register_extension_info(**kwargs):
pass
@@ -398,7 +396,7 @@ def tf_cc_binary(
srcs = srcs + tf_binary_additional_srcs(),
deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
[
- "//third_party/intel_mkl_ml",
+ "//third_party/mkl:intel_binary_blob",
],
),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
@@ -736,7 +734,7 @@ def tf_cc_test(
}) + linkopts + _rpath_linkopts(name),
deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
[
- "//third_party/intel_mkl_ml",
+ "//third_party/mkl:intel_binary_blob",
],
),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
diff --git a/tensorflow/tools/api/golden/BUILD b/tensorflow/tools/api/golden/BUILD
index 1f041ef193..4389a999e7 100644
--- a/tensorflow/tools/api/golden/BUILD
+++ b/tensorflow/tools/api/golden/BUILD
@@ -13,5 +13,5 @@ filegroup(
filegroup(
name = "api_golden_v2",
- srcs = glob(["v1/*.pbtxt"]),
+ srcs = glob(["v2/*.pbtxt"]),
)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 4de662fe33..a0d4ecc948 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -789,6 +789,10 @@ tf_module {
argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "batch_scatter_update"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "batch_to_space"
argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1005,6 +1009,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 9a831fed26..018be7b9f9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -5,6 +5,10 @@ tf_module {
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
+ name: "length"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
index ef9fe096a1..eb41deee13 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
@@ -14,5 +14,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "executor_type"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
index eeef15515d..e565b903d2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
@@ -137,6 +137,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "executor_type"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt
index 1f9aeb6ad6..4f0147a523 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.Iterator"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "initializer"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 9dbb5d16a4..c23b04b4ef 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 34a30c2874..6878d28fff 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
index 5aa4b3d4fb..bf1f94b6ae 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
@@ -11,6 +11,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "eval_distribute"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "evaluation_master"
mtype: "<type \'property\'>"
}
@@ -92,7 +96,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
index 6ec3aba775..5c46dc5ee7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
@@ -125,6 +125,10 @@ tf_module {
argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
+ name: "non_max_suppression_padded"
+ argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
+ }
+ member_method {
name: "pad_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index 40e82b18b6..e579fe6a1a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -135,7 +135,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 65cfad77d1..97688fcb0f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -140,7 +140,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
@@ -267,10 +267,6 @@ tf_class {
argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "test_on_batch"
argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
index 2cd83baf65..2e9de9ebb2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
@@ -22,7 +22,7 @@ tf_module {
}
member_method {
name: "relu"
- argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
+ argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], "
}
member_method {
name: "selu"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt
deleted file mode 100644
index 42cb914450..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt
+++ /dev/null
@@ -1,23 +0,0 @@
-path: "tensorflow.keras.applications.densenet"
-tf_module {
- member_method {
- name: "DenseNet121"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "DenseNet169"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "DenseNet201"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt
deleted file mode 100644
index 211080c19b..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.inception_resnet_v2"
-tf_module {
- member_method {
- name: "InceptionResNetV2"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt
deleted file mode 100644
index b67cee80ab..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.inception_v3"
-tf_module {
- member_method {
- name: "InceptionV3"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt
deleted file mode 100644
index ef774e1dd7..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.mobilenet"
-tf_module {
- member_method {
- name: "MobileNet"
- argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt
deleted file mode 100644
index cd75b87540..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt
+++ /dev/null
@@ -1,19 +0,0 @@
-path: "tensorflow.keras.applications.nasnet"
-tf_module {
- member_method {
- name: "NASNetLarge"
- argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "NASNetMobile"
- argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt
deleted file mode 100644
index 9fc086eb8e..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt
+++ /dev/null
@@ -1,87 +0,0 @@
-path: "tensorflow.keras.applications"
-tf_module {
- member {
- name: "densenet"
- mtype: "<type \'module\'>"
- }
- member {
- name: "inception_resnet_v2"
- mtype: "<type \'module\'>"
- }
- member {
- name: "inception_v3"
- mtype: "<type \'module\'>"
- }
- member {
- name: "mobilenet"
- mtype: "<type \'module\'>"
- }
- member {
- name: "nasnet"
- mtype: "<type \'module\'>"
- }
- member {
- name: "resnet50"
- mtype: "<type \'module\'>"
- }
- member {
- name: "vgg16"
- mtype: "<type \'module\'>"
- }
- member {
- name: "vgg19"
- mtype: "<type \'module\'>"
- }
- member {
- name: "xception"
- mtype: "<type \'module\'>"
- }
- member_method {
- name: "DenseNet121"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "DenseNet169"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "DenseNet201"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "InceptionResNetV2"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "InceptionV3"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "MobileNet"
- argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "NASNetLarge"
- argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "NASNetMobile"
- argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "ResNet50"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "VGG16"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "VGG19"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "Xception"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt
deleted file mode 100644
index 7385af064d..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.resnet50"
-tf_module {
- member_method {
- name: "ResNet50"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\', \'data_format\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'caffe\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt
deleted file mode 100644
index ba66fba8f3..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.vgg16"
-tf_module {
- member_method {
- name: "VGG16"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\', \'data_format\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'caffe\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt
deleted file mode 100644
index e55a1345b6..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.vgg19"
-tf_module {
- member_method {
- name: "VGG19"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\', \'data_format\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'caffe\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt
deleted file mode 100644
index 59dd2108f2..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.applications.xception"
-tf_module {
- member_method {
- name: "Xception"
- argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
- }
- member_method {
- name: "decode_predictions"
- argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index fddac63b78..126ce8db6a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -366,7 +366,7 @@ tf_module {
}
member_method {
name: "relu"
- argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
+ argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], "
}
member_method {
name: "repeat"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
index 5d05cf689f..2dff7a6de4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
@@ -118,7 +118,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index f754fa1da8..ff19dcc3a3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index c9516b8f07..3c278fead6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
index c00fa79adf..4d3de58bd1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'max_value\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'max_value\', \'negative_slope\', \'threshold\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0\', \'0\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 1160d2840f..6718e36dc6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -61,6 +61,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "state_size"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 85f7c2bfed..56914e1746 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -135,7 +135,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 6a83129f7d..acfb3521c0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -140,7 +140,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
@@ -267,10 +267,6 @@ tf_class {
argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "test_on_batch"
argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt
deleted file mode 100644
index dddace87dc..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt
+++ /dev/null
@@ -1,23 +0,0 @@
-path: "tensorflow.keras.preprocessing.image.DirectoryIterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.DirectoryIterator\'>"
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.Iterator\'>"
- is_instance: "<class \'tensorflow.python.keras.utils.data_utils.Sequence\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], "
- }
- member_method {
- name: "next"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "on_epoch_end"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt
deleted file mode 100644
index c1e2e94f0b..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt
+++ /dev/null
@@ -1,29 +0,0 @@
-path: "tensorflow.keras.preprocessing.image.ImageDataGenerator"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.ImageDataGenerator\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'brightness_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\', \'validation_split\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\', \'0.0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'augment\', \'rounds\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'1\', \'None\'], "
- }
- member_method {
- name: "flow"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'None\'], "
- }
- member_method {
- name: "flow_from_directory"
- argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], "
- }
- member_method {
- name: "random_transform"
- argspec: "args=[\'self\', \'x\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "standardize"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt
deleted file mode 100644
index 825d9f1d1d..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt
+++ /dev/null
@@ -1,18 +0,0 @@
-path: "tensorflow.keras.preprocessing.image.Iterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.Iterator\'>"
- is_instance: "<class \'tensorflow.python.keras.utils.data_utils.Sequence\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'n\', \'batch_size\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "on_epoch_end"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt
deleted file mode 100644
index 75924a254a..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt
+++ /dev/null
@@ -1,23 +0,0 @@
-path: "tensorflow.keras.preprocessing.image.NumpyArrayIterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.NumpyArrayIterator\'>"
- is_instance: "<class \'tensorflow.python.keras.preprocessing.image.Iterator\'>"
- is_instance: "<class \'tensorflow.python.keras.utils.data_utils.Sequence\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\', \'None\'], "
- }
- member_method {
- name: "next"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "on_epoch_end"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
deleted file mode 100644
index 6b850dd6b7..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
+++ /dev/null
@@ -1,63 +0,0 @@
-path: "tensorflow.keras.preprocessing.image"
-tf_module {
- member {
- name: "DirectoryIterator"
- mtype: "<type \'type\'>"
- }
- member {
- name: "ImageDataGenerator"
- mtype: "<type \'type\'>"
- }
- member {
- name: "Iterator"
- mtype: "<type \'type\'>"
- }
- member {
- name: "NumpyArrayIterator"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "apply_transform"
- argspec: "args=[\'x\', \'transform_matrix\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'0\', \'nearest\', \'0.0\'], "
- }
- member_method {
- name: "array_to_img"
- argspec: "args=[\'x\', \'data_format\', \'scale\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
- }
- member_method {
- name: "flip_axis"
- argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "img_to_array"
- argspec: "args=[\'img\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "load_img"
- argspec: "args=[\'path\', \'grayscale\', \'target_size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'nearest\'], "
- }
- member_method {
- name: "random_brightness"
- argspec: "args=[\'x\', \'brightness_range\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "random_channel_shift"
- argspec: "args=[\'x\', \'intensity\', \'channel_axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
- }
- member_method {
- name: "random_rotation"
- argspec: "args=[\'x\', \'rg\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
- }
- member_method {
- name: "random_shear"
- argspec: "args=[\'x\', \'intensity\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
- }
- member_method {
- name: "random_shift"
- argspec: "args=[\'x\', \'wrg\', \'hrg\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
- }
- member_method {
- name: "random_zoom"
- argspec: "args=[\'x\', \'zoom_range\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt
deleted file mode 100644
index 5a78581fc5..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.keras.preprocessing"
-tf_module {
- member {
- name: "image"
- mtype: "<type \'module\'>"
- }
- member {
- name: "sequence"
- mtype: "<type \'module\'>"
- }
- member {
- name: "text"
- mtype: "<type \'module\'>"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt
deleted file mode 100644
index 326b1fa4fd..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt
+++ /dev/null
@@ -1,14 +0,0 @@
-path: "tensorflow.keras.preprocessing.sequence.TimeseriesGenerator"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.sequence.TimeseriesGenerator\'>"
- is_instance: "<class \'tensorflow.python.keras.utils.data_utils.Sequence\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'data\', \'targets\', \'length\', \'sampling_rate\', \'stride\', \'start_index\', \'end_index\', \'shuffle\', \'reverse\', \'batch_size\'], varargs=None, keywords=None, defaults=[\'1\', \'1\', \'0\', \'None\', \'False\', \'False\', \'128\'], "
- }
- member_method {
- name: "on_epoch_end"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt
deleted file mode 100644
index cf59f8a272..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt
+++ /dev/null
@@ -1,19 +0,0 @@
-path: "tensorflow.keras.preprocessing.sequence"
-tf_module {
- member {
- name: "TimeseriesGenerator"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "make_sampling_table"
- argspec: "args=[\'size\', \'sampling_factor\'], varargs=None, keywords=None, defaults=[\'1e-05\'], "
- }
- member_method {
- name: "pad_sequences"
- argspec: "args=[\'sequences\', \'maxlen\', \'dtype\', \'padding\', \'truncating\', \'value\'], varargs=None, keywords=None, defaults=[\'None\', \'int32\', \'pre\', \'pre\', \'0.0\'], "
- }
- member_method {
- name: "skipgrams"
- argspec: "args=[\'sequence\', \'vocabulary_size\', \'window_size\', \'negative_samples\', \'shuffle\', \'categorical\', \'sampling_table\', \'seed\'], varargs=None, keywords=None, defaults=[\'4\', \'1.0\', \'True\', \'False\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
deleted file mode 100644
index b42b12b6c0..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
+++ /dev/null
@@ -1,33 +0,0 @@
-path: "tensorflow.keras.preprocessing.text.Tokenizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.preprocessing.text.Tokenizer\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\', \'oov_token\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\', \'None\'], "
- }
- member_method {
- name: "fit_on_sequences"
- argspec: "args=[\'self\', \'sequences\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "fit_on_texts"
- argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "sequences_to_matrix"
- argspec: "args=[\'self\', \'sequences\', \'mode\'], varargs=None, keywords=None, defaults=[\'binary\'], "
- }
- member_method {
- name: "texts_to_matrix"
- argspec: "args=[\'self\', \'texts\', \'mode\'], varargs=None, keywords=None, defaults=[\'binary\'], "
- }
- member_method {
- name: "texts_to_sequences"
- argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "texts_to_sequences_generator"
- argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt
deleted file mode 100644
index 50b54fc7e1..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt
+++ /dev/null
@@ -1,19 +0,0 @@
-path: "tensorflow.keras.preprocessing.text"
-tf_module {
- member {
- name: "Tokenizer"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "hashing_trick"
- argspec: "args=[\'text\', \'n\', \'hash_function\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], "
- }
- member_method {
- name: "one_hot"
- argspec: "args=[\'text\', \'n\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], "
- }
- member_method {
- name: "text_to_word_sequence"
- argspec: "args=[\'text\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index c74773000a..e606eab919 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -125,7 +125,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index d251f54806..5deb02d569 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -125,7 +125,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index d76eab7eb8..32fa151a8e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -125,7 +125,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 944db6ac93..30c6c2ce3b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -101,7 +101,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -125,7 +125,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 5eb42b4db3..a0d4ecc948 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -785,6 +785,14 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "batch_gather"
+ argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "batch_scatter_update"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "batch_to_space"
argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1001,6 +1009,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1902,19 +1914,19 @@ tf_module {
}
member_method {
name: "sparse_reduce_max"
- argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "sparse_reduce_max_sparse"
- argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "sparse_reduce_sum"
- argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "sparse_reduce_sum_sparse"
- argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "sparse_reorder"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 9a831fed26..018be7b9f9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -5,6 +5,10 @@ tf_module {
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
+ name: "length"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
index 871ebb5247..7ed9cd77a0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
@@ -50,7 +50,7 @@ tf_module {
}
member_method {
name: "merge_all"
- argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'summaries\', \'None\'], "
+ argspec: "args=[\'key\', \'scope\', \'name\'], varargs=None, keywords=None, defaults=[\'summaries\', \'None\', \'None\'], "
}
member_method {
name: "scalar"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt
index 2d067e4eff..5be37200f3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt
@@ -20,4 +20,8 @@ tf_class {
name: "save"
argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index b0fb04d7d4..9f35395284 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -298,7 +298,7 @@ tf_module {
}
member_method {
name: "generate_checkpoint_state_proto"
- argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "get_checkpoint_mtimes"
@@ -446,7 +446,7 @@ tf_module {
}
member_method {
name: "update_checkpoint_state"
- argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "warm_start"
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index b65dbc4b7d..43d19bc99c 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -79,7 +79,7 @@ def _KeyToFilePath(key, api_version):
case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key)
api_folder = (
_API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1)
- return os.path.join(_API_GOLDEN_FOLDER_V1, '%s.pbtxt' % case_insensitive_key)
+ return os.path.join(api_folder, '%s.pbtxt' % case_insensitive_key)
def _FileNameToKey(filename):
@@ -315,7 +315,7 @@ class ApiCompatibilityTest(test.TestCase):
def testAPIBackwardsCompatibilityV2(self):
if not hasattr(tf.compat, 'v2'):
return
- api_version = 1
+ api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 47e0e5dd59..5d0a8efc69 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -68,6 +68,7 @@ TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
# --test_contrib_only Use tensorflow/contrib/... as test target
for ARG in "$@"; do
case "$ARG" in
+ --tf_nightly) TF_NIGHTLY=1 ;;
--skip_test) SKIP_TEST=1 ;;
--enable_remote_cache) set_remote_cache_options ;;
--release_build) RELEASE_BUILD=1 ;;
@@ -86,6 +87,11 @@ else
export TF_OVERRIDE_EIGEN_STRONG_INLINE=1
fi
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ python tensorflow/tools/ci_build/update_version.py --nightly
+ EXTRA_PIP_FLAG="--nightly_flag"
+fi
+
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
@@ -104,7 +110,11 @@ fi
# Create a python test directory to avoid package name conflict
create_python_test_dir "${PY_TEST_DIR}"
-./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}"
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" "${EXTRA_PIP_FLAG}"
+
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ exit 0
+fi
# Running python tests on Windows needs pip package installed
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index e3eee11080..7ac07872e9 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -68,6 +68,7 @@ TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
# --test_contrib_only Use tensorflow/contrib/... as test target
for ARG in "$@"; do
case "$ARG" in
+ --tf_nightly) TF_NIGHTLY=1 ;;
--skip_test) SKIP_TEST=1 ;;
--enable_remote_cache) set_remote_cache_options ;;
--release_build) RELEASE_BUILD=1 ;;
@@ -86,6 +87,11 @@ else
export TF_OVERRIDE_EIGEN_STRONG_INLINE=1
fi
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ python tensorflow/tools/ci_build/update_version.py --nightly
+ EXTRA_PIP_FLAG="--nightly_flag"
+fi
+
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
@@ -107,7 +113,11 @@ fi
# Create a python test directory to avoid package name conflict
create_python_test_dir "${PY_TEST_DIR}"
-./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}"
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" --gpu "${EXTRA_PIP_FLAG}"
+
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ exit 0
+fi
# Running python tests on Windows needs pip package installed
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
index 09933d266b..82bb0713c4 100644
--- a/tensorflow/tools/common/public_api.py
+++ b/tensorflow/tools/common/public_api.py
@@ -102,9 +102,10 @@ class PublicAPIVisitor(object):
"""Override the default root name of 'tf'."""
self._root_name = root_name
- def _is_private(self, path, name):
+ def _is_private(self, path, name, obj=None):
"""Return whether a name is private."""
# TODO(wicke): Find out what names to exclude.
+ del obj # Unused.
return ((path in self._private_map and
name in self._private_map[path]) or
(name.startswith('_') and not re.match('__.*__$', name) or
@@ -129,7 +130,7 @@ class PublicAPIVisitor(object):
# Remove things that are not visible.
for name, child in list(children):
- if self._is_private(full_path, name):
+ if self._is_private(full_path, name, child):
children.remove((name, child))
self._visitor(path, parent, children)
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index c9f17a8242..448a3a7647 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -387,7 +387,7 @@ else # TF_DOCKER_BUILD_IS_DEVEL == 'yes'
TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3")
cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
else
- if [[ "${TF_DOCKER_BUILD_TYPE}" != "mkl" ]]; then
+ if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3.6" ]] && [[ "${TF_DOCKER_BUILD_TYPE}" != "mkl" ]]; then
die "Python 3.6 build only supported for MKL builds."
fi
if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \
diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py
index f96887e4c7..fc93085e3e 100644
--- a/tensorflow/tools/docs/generate.py
+++ b/tensorflow/tools/docs/generate.py
@@ -31,11 +31,6 @@ if __name__ == '__main__':
doc_generator = generate_lib.DocGenerator()
doc_generator.add_output_dir_argument()
doc_generator.add_src_dir_argument()
- doc_generator.argument_parser.add_argument(
- '--site_api_path',
- type=str, default='api_docs/python',
- help='The path from the site-root to api_docs'
- 'directory for this project')
# This doc generator works on the TensorFlow codebase. Since this script lives
# at tensorflow/tools/docs, and all code is defined somewhere inside
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 9387042224..653e46fc41 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -22,6 +22,7 @@ import argparse
import fnmatch
import os
import shutil
+import tempfile
import six
@@ -57,7 +58,7 @@ def write_docs(output_dir,
yaml_toc,
root_title='TensorFlow',
search_hints=True,
- site_api_path=None):
+ site_api_path=''):
"""Write previously extracted docs to disk.
Write a docs page for each symbol included in the indices of parser_config to
@@ -75,8 +76,8 @@ def write_docs(output_dir,
root_title: The title name for the root level index.md.
search_hints: (bool) include meta-data search hints at the top of each
output file.
- site_api_path: Used to write the api-duplicates _redirects.yaml file. if
- None (the default) the file is not generated.
+ site_api_path: The output path relative to the site root. Used in the
+ `_toc.yaml` and `_redirects.yaml` files.
Raises:
ValueError: if `output_dir` is not an absolute path
@@ -111,9 +112,6 @@ def write_docs(output_dir,
_is_free_function(py_object, full_name, parser_config.index)):
continue
- if doc_controls.should_skip(py_object):
- continue
-
sitepath = os.path.join('api_docs/python',
parser.documentation_path(full_name)[:-3])
@@ -160,27 +158,27 @@ def write_docs(output_dir,
raise OSError(
'Cannot write documentation for %s to %s' % (full_name, directory))
- if site_api_path:
- duplicates = parser_config.duplicates.get(full_name, [])
- if not duplicates:
- continue
+ duplicates = parser_config.duplicates.get(full_name, [])
+ if not duplicates:
+ continue
- duplicates = [item for item in duplicates if item != full_name]
+ duplicates = [item for item in duplicates if item != full_name]
- for dup in duplicates:
- from_path = os.path.join(site_api_path, dup.replace('.', '/'))
- to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
- redirects.append((from_path, to_path))
+ for dup in duplicates:
+ from_path = os.path.join(site_api_path, dup.replace('.', '/'))
+ to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
+ redirects.append((
+ os.path.join('/', from_path),
+ os.path.join('/', to_path)))
- if site_api_path and redirects:
- redirects = sorted(redirects)
- template = ('- from: /{}\n'
- ' to: /{}\n')
- redirects = [template.format(f, t) for f, t in redirects]
- api_redirects_path = os.path.join(output_dir, '_redirects.yaml')
- with open(api_redirects_path, 'w') as redirect_file:
- redirect_file.write('redirects:\n')
- redirect_file.write(''.join(redirects))
+ redirects = sorted(redirects)
+ template = ('- from: {}\n'
+ ' to: {}\n')
+ redirects = [template.format(f, t) for f, t in redirects]
+ api_redirects_path = os.path.join(output_dir, '_redirects.yaml')
+ with open(api_redirects_path, 'w') as redirect_file:
+ redirect_file.write('redirects:\n')
+ redirect_file.write(''.join(redirects))
if yaml_toc:
# Generate table of contents
@@ -210,7 +208,8 @@ def write_docs(output_dir,
'- title: ' + title,
' section:',
' - title: Overview',
- ' path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[module]]
+ ' path: ' + os.path.join('/', site_api_path,
+ symbol_to_file[module])]
header = ''.join([indent+line+'\n' for line in header])
f.write(header)
@@ -221,7 +220,8 @@ def write_docs(output_dir,
for full_name in symbols_in_module:
item = [
' - title: ' + full_name[len(module) + 1:],
- ' path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[full_name]]
+ ' path: ' + os.path.join('/', site_api_path,
+ symbol_to_file[full_name])]
item = ''.join([indent+line+'\n' for line in item])
f.write(item)
@@ -295,6 +295,15 @@ def _get_default_do_not_descend_map():
}
+class DocControlsAwareCrawler(public_api.PublicAPIVisitor):
+ """A `docs_controls` aware API-crawler."""
+
+ def _is_private(self, path, name, obj):
+ if doc_controls.should_skip(obj):
+ return True
+ return super(DocControlsAwareCrawler, self)._is_private(path, name, obj)
+
+
def extract(py_modules,
private_map,
do_not_descend_map,
@@ -302,7 +311,7 @@ def extract(py_modules,
"""Extract docs from tf namespace and write them to disk."""
# Traverse the first module.
visitor = visitor_cls(py_modules[0][0])
- api_visitor = public_api.PublicAPIVisitor(visitor)
+ api_visitor = DocControlsAwareCrawler(visitor)
api_visitor.set_root_name(py_modules[0][0])
add_dict_to_dict(private_map, api_visitor.private_map)
add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)
@@ -532,6 +541,12 @@ class DocGenerator(object):
action='store_false',
default=True)
+ self.argument_parser.add_argument(
+ '--site_api_path',
+ type=str, default='',
+ help='The path from the site-root to api_docs'
+ 'directory for this project')
+
def add_output_dir_argument(self):
self.argument_parser.add_argument(
'--output_dir',
@@ -544,9 +559,9 @@ class DocGenerator(object):
self.argument_parser.add_argument(
'--src_dir',
type=str,
- default=None,
- required=True,
- help='Directory with the source docs.')
+ default=tempfile.mkdtemp(),
+ required=False,
+ help='Optional directory of source docs to add api_docs links to')
def add_base_dir_argument(self, default_base_dir):
self.argument_parser.add_argument(
@@ -648,7 +663,7 @@ class DocGenerator(object):
yaml_toc=self.yaml_toc,
root_title=root_title,
search_hints=getattr(flags, 'search_hints', True),
- site_api_path=getattr(flags, 'site_api_path', None))
+ site_api_path=getattr(flags, 'site_api_path', ''))
# Replace all the @{} references in files under `FLAGS.src_dir`
replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md')
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 801c8bcb4a..8e444a15cf 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -1695,15 +1695,18 @@ class _Metadata(object):
Attributes:
name: The name of the page being described by the Metadata block.
+ version: The source version.
"""
- def __init__(self, name):
+ def __init__(self, name, version='stable'):
"""Creates a Metadata builder.
Args:
name: The name of the page being described by the Metadata block.
+ version: The source version.
"""
self.name = name
+ self.version = version
self._content = []
def append(self, item):
@@ -1720,6 +1723,7 @@ class _Metadata(object):
parts = ['<div itemscope itemtype="%s">' % schema]
parts.append('<meta itemprop="name" content="%s" />' % self.name)
+ parts.append('<meta itemprop="path" content="%s" />' % self.version)
for item in self._content:
parts.append('<meta itemprop="property" content="%s"/>' % item)
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 6bba139b4d..1a4679c8a3 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -200,21 +200,30 @@ filegroup(
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
],
- ) + tf_additional_license_deps(),
+ ) + if_not_windows([
+ "@ngraph//:LICENSE",
+ "@ngraph_tf//:LICENSE",
+ "@nlohmann_json_lib//:LICENSE",
+ ]) + tf_additional_license_deps(),
)
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = select({
- "//tensorflow:windows": [":simple_console_for_windows"],
+ "//tensorflow:windows": [
+ ":simple_console_for_windows",
+ "//tensorflow/contrib/lite/python:interpreter_test_data",
+ "//tensorflow/contrib/lite/python:tflite_convert",
+ "//tensorflow/contrib/lite/toco/python:toco_from_protos",
+ ],
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
"//tensorflow/contrib/lite/python:interpreter_test_data",
"//tensorflow/contrib/lite/python:tflite_convert",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
],
- }) + if_mkl_ml(["//third_party/intel_mkl_ml"]),
+ }) + if_mkl_ml(["//third_party/mkl:intel_binary_blob"]),
)
# A genrule for generating a marker file for the pip package on Windows
diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in
index 86c5e4776d..c4b4af93b8 100644
--- a/tensorflow/tools/pip_package/MANIFEST.in
+++ b/tensorflow/tools/pip_package/MANIFEST.in
@@ -1,5 +1,6 @@
include README
recursive-include * *.py
+recursive-include * *.pd
recursive-include * *.so
recursive-include * *.dll
recursive-include * *.lib
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index fc2c041b6c..b4b70e0a78 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -39,6 +39,7 @@ cc_binary(
":gen_proto_text_functions_lib",
"@protobuf_archive//:protobuf",
"//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:lib_proto_compiler",
] + if_ios(["//tensorflow/core/platform/default/build_config:logging"]),
)
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
index 234afe879b..159976f1b0 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/protobuf_compiler.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 10bfe7e0f6..dfb5d8a6e1 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -20,6 +20,10 @@ load(
"def_file_filter_configure",
)
+def initialize_third_party():
+ # Fill in later
+ pass
+
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
@@ -40,6 +44,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
syslibs_configure(name = "local_config_syslibs")
python_configure(name = "local_config_python")
+ initialize_third_party()
+
# For windows bazel build
# TODO: Remove def file filter when TensorFlow can export symbols properly on Windows.
def_file_filter_configure(name = "local_config_def_file_filter")
@@ -395,21 +401,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "nsync",
urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz",
- "https://github.com/google/nsync/archive/1.20.0.tar.gz",
+ "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.1.tar.gz",
+ "https://github.com/google/nsync/archive/1.20.1.tar.gz",
],
- sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd",
- strip_prefix = "nsync-1.20.0",
+ sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
+ strip_prefix = "nsync-1.20.1",
+ system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
)
tf_http_archive(
name = "com_google_googletest",
urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
- "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
+ "https://mirror.bazel.build/github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
+ "https://github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
],
- sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d",
- strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6",
+ sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
+ strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
)
tf_http_archive(
@@ -486,11 +493,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz",
],
- sha256 = "7543322052e27e70f882801ef70a45afc268e09aaf6a07b840450bfcac366eb6",
- strip_prefix = "llvm-17454e67ca55357e103cec104c3dc973bbb11ff0",
+ sha256 = "83a80f9fb2a5949ca77e526344cbd4581388c3ec7fea5c59e488d46fd38e06d9",
+ strip_prefix = "llvm-6203c9bd082a877a20c218033636712135a3c2db",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
@@ -521,11 +528,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "boringssl",
urls = [
- "https://mirror.bazel.build/github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz",
- "https://github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz",
+ "https://mirror.bazel.build/github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
+ "https://github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
],
- sha256 = "972e8d8a9d1daf9892fff7155312b1af46b4754446575a7b285e62f917424c78",
- strip_prefix = "boringssl-45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8",
+ sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
+ strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
)
tf_http_archive(
@@ -576,11 +583,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "kafka",
urls = [
- "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
- "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
+ "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
+ "https://github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
],
- sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c",
- strip_prefix = "librdkafka-0.11.4",
+ sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
+ strip_prefix = "librdkafka-0.11.5",
build_file = clean_dep("//third_party:kafka/BUILD"),
patch_file = clean_dep("//third_party/kafka:config.patch"),
)
diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl
index ff6b3cc351..325d18b9cb 100644
--- a/third_party/gpus/cuda/BUILD.windows.tpl
+++ b/third_party/gpus/cuda/BUILD.windows.tpl
@@ -142,6 +142,7 @@ cc_library(
],
includes = [
".",
+ "cuda/",
"cuda/extras/CUPTI/include/",
],
visibility = ["//visibility:public"],
diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD
index 3c50b8cf52..11ec50069a 100644
--- a/third_party/kafka/BUILD
+++ b/third_party/kafka/BUILD
@@ -48,8 +48,13 @@ cc_library(
"src/rdinterval.h",
"src/rdkafka.c",
"src/rdkafka.h",
+ "src/rdkafka_admin.c",
+ "src/rdkafka_admin.h",
"src/rdkafka_assignor.c",
"src/rdkafka_assignor.h",
+ "src/rdkafka_aux.c",
+ "src/rdkafka_aux.h",
+ "src/rdkafka_background.c",
"src/rdkafka_broker.c",
"src/rdkafka_broker.h",
"src/rdkafka_buf.c",
@@ -58,6 +63,7 @@ cc_library(
"src/rdkafka_cgrp.h",
"src/rdkafka_conf.c",
"src/rdkafka_conf.h",
+ "src/rdkafka_confval.h",
"src/rdkafka_event.h",
"src/rdkafka_feature.c",
"src/rdkafka_feature.h",
diff --git a/third_party/ngraph/build_defs.bzl b/third_party/ngraph/build_defs.bzl
index 8ad7515aed..3c34be524b 100644
--- a/third_party/ngraph/build_defs.bzl
+++ b/third_party/ngraph/build_defs.bzl
@@ -1,14 +1,11 @@
+"""Build configurations for nGraph."""
+
def clean_dep(dep):
- return str(Label(dep))
+ return str(Label(dep))
def if_ngraph(if_true, if_false = []):
- """Shorthand for select()'ing on whether we're building with nGraph support.
-
- Returns a select statement which evaluates to if_true if we're building
- with nGraph. Otherwise, the select statement evaluates to default.
-
- """
+ """select()'ing on whether we're building with nGraph support."""
return select({
clean_dep("//tensorflow:with_ngraph_support"): if_true,
- "//conditions:default": if_false
+ "//conditions:default": if_false,
})
diff --git a/third_party/ngraph/ngraph.BUILD b/third_party/ngraph/ngraph.BUILD
index 17710b2cb9..f73ce4f674 100644
--- a/third_party/ngraph/ngraph.BUILD
+++ b/third_party/ngraph/ngraph.BUILD
@@ -13,22 +13,22 @@ filegroup(
cc_library(
name = "ngraph_core",
srcs = glob([
- "src/ngraph/*.cpp",
- "src/ngraph/autodiff/*.cpp",
- "src/ngraph/builder/*.cpp",
- "src/ngraph/descriptor/*.cpp",
- "src/ngraph/descriptor/layout/*.cpp",
- "src/ngraph/op/*.cpp",
- "src/ngraph/op/util/*.cpp",
- "src/ngraph/pattern/*.cpp",
- "src/ngraph/pattern/*.hpp",
- "src/ngraph/pass/*.cpp",
- "src/ngraph/pass/*.hpp",
- "src/ngraph/runtime/*.cpp",
- "src/ngraph/type/*.cpp",
- "src/ngraph/runtime/interpreter/*.cpp",
- "src/ngraph/runtime/interpreter/*.hpp",
- ]),
+ "src/ngraph/*.cpp",
+ "src/ngraph/autodiff/*.cpp",
+ "src/ngraph/builder/*.cpp",
+ "src/ngraph/descriptor/*.cpp",
+ "src/ngraph/descriptor/layout/*.cpp",
+ "src/ngraph/op/*.cpp",
+ "src/ngraph/op/util/*.cpp",
+ "src/ngraph/pattern/*.cpp",
+ "src/ngraph/pattern/*.hpp",
+ "src/ngraph/pass/*.cpp",
+ "src/ngraph/pass/*.hpp",
+ "src/ngraph/runtime/*.cpp",
+ "src/ngraph/type/*.cpp",
+ "src/ngraph/runtime/interpreter/*.cpp",
+ "src/ngraph/runtime/interpreter/*.hpp",
+ ]),
hdrs = glob(["src/ngraph/**/*.hpp"]),
deps = [
"@eigen_archive//:eigen",
@@ -41,5 +41,5 @@ cc_library(
'-D NGRAPH_VERSION=\\"0.5.0\\"',
],
visibility = ["//visibility:public"],
- alwayslink=1
+ alwayslink = 1,
)
diff --git a/third_party/ngraph/ngraph_tf.BUILD b/third_party/ngraph/ngraph_tf.BUILD
index f36532449c..0c2c8a718f 100644
--- a/third_party/ngraph/ngraph_tf.BUILD
+++ b/third_party/ngraph/ngraph_tf.BUILD
@@ -12,7 +12,7 @@ filegroup(
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
- "tf_cc_test"
+ "tf_cc_test",
)
cc_library(
@@ -54,13 +54,13 @@ cc_library(
"logging/tf_graph_writer.cc",
],
hdrs = [
- "src/tf_graphcycles.h"
+ "src/tf_graphcycles.h",
],
deps = [
"@org_tensorflow//tensorflow/core:protos_all_proto_text",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
"@org_tensorflow//tensorflow/core:core_cpu_headers_lib",
- "@ngraph//:ngraph_core"
+ "@ngraph//:ngraph_core",
],
copts = [
"-I external/ngraph_tf/src",
@@ -68,7 +68,7 @@ cc_library(
"-I external/ngraph/src",
"-D NGRAPH_EMBEDDED_IN_TENSORFLOW=1",
],
- alwayslink=1,
+ alwayslink = 1,
visibility = ["//visibility:public"],
)
diff --git a/third_party/ngraph/nlohmann_json.BUILD b/third_party/ngraph/nlohmann_json.BUILD
index 396e158535..a0b18a51cb 100644
--- a/third_party/ngraph/nlohmann_json.BUILD
+++ b/third_party/ngraph/nlohmann_json.BUILD
@@ -19,5 +19,5 @@ cc_library(
"-I external/nlohmann_json_lib",
],
visibility = ["//visibility:public"],
- alwayslink=1
+ alwayslink = 1,
)
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
index 5cb42691c5..7d1aa5dce9 100644
--- a/third_party/repo.bzl
+++ b/third_party/repo.bzl
@@ -19,104 +19,111 @@ _SINGLE_URL_WHITELIST = depset([
])
def _is_windows(ctx):
- return ctx.os.name.lower().find("windows") != -1
+ return ctx.os.name.lower().find("windows") != -1
def _wrap_bash_cmd(ctx, cmd):
- if _is_windows(ctx):
- bazel_sh = _get_env_var(ctx, "BAZEL_SH")
- if not bazel_sh:
- fail("BAZEL_SH environment variable is not set")
- cmd = [bazel_sh, "-l", "-c", " ".join(cmd)]
- return cmd
+ if _is_windows(ctx):
+ bazel_sh = _get_env_var(ctx, "BAZEL_SH")
+ if not bazel_sh:
+ fail("BAZEL_SH environment variable is not set")
+ cmd = [bazel_sh, "-l", "-c", " ".join(cmd)]
+ return cmd
def _get_env_var(ctx, name):
- if name in ctx.os.environ:
- return ctx.os.environ[name]
- else:
- return None
+ if name in ctx.os.environ:
+ return ctx.os.environ[name]
+ else:
+ return None
# Checks if we should use the system lib instead of the bundled one
def _use_system_lib(ctx, name):
- syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS")
- if syslibenv:
- for n in syslibenv.strip().split(","):
- if n.strip() == name:
- return True
- return False
+ syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS")
+ if syslibenv:
+ for n in syslibenv.strip().split(","):
+ if n.strip() == name:
+ return True
+ return False
# Executes specified command with arguments and calls 'fail' if it exited with
# non-zero code
def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
- result = repo_ctx.execute(cmd_and_args, timeout=10)
- if result.return_code != 0:
- fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n"
- + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
- result.stdout, result.stderr))
+ result = repo_ctx.execute(cmd_and_args, timeout = 10)
+ if result.return_code != 0:
+ fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" +
+ "Stderr: {3}").format(
+ " ".join(cmd_and_args),
+ result.return_code,
+ result.stdout,
+ result.stderr,
+ ))
def _repos_are_siblings():
- return Label("@foo//bar").workspace_root.startswith("../")
+ return Label("@foo//bar").workspace_root.startswith("../")
# Apply a patch_file to the repository root directory
# Runs 'patch -p1'
def _apply_patch(ctx, patch_file):
- # Don't check patch on Windows, because patch is only available under bash.
- if not _is_windows(ctx) and not ctx.which("patch"):
- fail("patch command is not found, please install it")
- cmd = _wrap_bash_cmd(
- ctx, ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)])
- _execute_and_check_ret_code(ctx, cmd)
+ # Don't check patch on Windows, because patch is only available under bash.
+ if not _is_windows(ctx) and not ctx.which("patch"):
+ fail("patch command is not found, please install it")
+ cmd = _wrap_bash_cmd(
+ ctx,
+ ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)],
+ )
+ _execute_and_check_ret_code(ctx, cmd)
def _apply_delete(ctx, paths):
- for path in paths:
- if path.startswith("/"):
- fail("refusing to rm -rf path starting with '/': " + path)
- if ".." in path:
- fail("refusing to rm -rf path containing '..': " + path)
- cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
- _execute_and_check_ret_code(ctx, cmd)
+ for path in paths:
+ if path.startswith("/"):
+ fail("refusing to rm -rf path starting with '/': " + path)
+ if ".." in path:
+ fail("refusing to rm -rf path containing '..': " + path)
+ cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
+ _execute_and_check_ret_code(ctx, cmd)
def _tf_http_archive(ctx):
- if ("mirror.bazel.build" not in ctx.attr.urls[0] and
- (len(ctx.attr.urls) < 2 and
- ctx.attr.name not in _SINGLE_URL_WHITELIST)):
- fail("tf_http_archive(urls) must have redundant URLs. The " +
- "mirror.bazel.build URL must be present and it must come first. " +
- "Even if you don't have permission to mirror the file, please " +
- "put the correctly formatted mirror URL there anyway, because " +
- "someone will come along shortly thereafter and mirror the file.")
-
- use_syslib = _use_system_lib(ctx, ctx.attr.name)
- if not use_syslib:
- ctx.download_and_extract(
- ctx.attr.urls,
- "",
- ctx.attr.sha256,
- ctx.attr.type,
- ctx.attr.strip_prefix)
- if ctx.attr.delete:
- _apply_delete(ctx, ctx.attr.delete)
- if ctx.attr.patch_file != None:
- _apply_patch(ctx, ctx.attr.patch_file)
-
- if use_syslib and ctx.attr.system_build_file != None:
- # Use BUILD.bazel to avoid conflict with third party projects with
- # BUILD or build (directory) underneath.
- ctx.template("BUILD.bazel", ctx.attr.system_build_file, {
- "%prefix%": ".." if _repos_are_siblings() else "external",
- }, False)
-
- elif ctx.attr.build_file != None:
- # Use BUILD.bazel to avoid conflict with third party projects with
- # BUILD or build (directory) underneath.
- ctx.template("BUILD.bazel", ctx.attr.build_file, {
- "%prefix%": ".." if _repos_are_siblings() else "external",
- }, False)
+ if ("mirror.bazel.build" not in ctx.attr.urls[0] and
+ (len(ctx.attr.urls) < 2 and
+ ctx.attr.name not in _SINGLE_URL_WHITELIST)):
+ fail("tf_http_archive(urls) must have redundant URLs. The " +
+ "mirror.bazel.build URL must be present and it must come first. " +
+ "Even if you don't have permission to mirror the file, please " +
+ "put the correctly formatted mirror URL there anyway, because " +
+ "someone will come along shortly thereafter and mirror the file.")
+
+ use_syslib = _use_system_lib(ctx, ctx.attr.name)
+ if not use_syslib:
+ ctx.download_and_extract(
+ ctx.attr.urls,
+ "",
+ ctx.attr.sha256,
+ ctx.attr.type,
+ ctx.attr.strip_prefix,
+ )
+ if ctx.attr.delete:
+ _apply_delete(ctx, ctx.attr.delete)
+ if ctx.attr.patch_file != None:
+ _apply_patch(ctx, ctx.attr.patch_file)
+
+ if use_syslib and ctx.attr.system_build_file != None:
+ # Use BUILD.bazel to avoid conflict with third party projects with
+ # BUILD or build (directory) underneath.
+ ctx.template("BUILD.bazel", ctx.attr.system_build_file, {
+ "%prefix%": ".." if _repos_are_siblings() else "external",
+ }, False)
+
+ elif ctx.attr.build_file != None:
+ # Use BUILD.bazel to avoid conflict with third party projects with
+ # BUILD or build (directory) underneath.
+ ctx.template("BUILD.bazel", ctx.attr.build_file, {
+ "%prefix%": ".." if _repos_are_siblings() else "external",
+ }, False)
tf_http_archive = repository_rule(
- implementation=_tf_http_archive,
- attrs={
- "sha256": attr.string(mandatory=True),
- "urls": attr.string_list(mandatory=True, allow_empty=False),
+ implementation = _tf_http_archive,
+ attrs = {
+ "sha256": attr.string(mandatory = True),
+ "urls": attr.string_list(mandatory = True, allow_empty = False),
"strip_prefix": attr.string(),
"type": attr.string(),
"delete": attr.string_list(),
@@ -124,12 +131,78 @@ tf_http_archive = repository_rule(
"build_file": attr.label(),
"system_build_file": attr.label(),
},
- environ=[
- "TF_SYSTEM_LIBS",
- ])
+ environ = [
+ "TF_SYSTEM_LIBS",
+ ],
+)
"""Downloads and creates Bazel repos for dependencies.
This is a swappable replacement for both http_archive() and
new_http_archive() that offers some additional features. It also helps
ensure best practices are followed.
"""
+
+def _third_party_http_archive(ctx):
+ if ("mirror.bazel.build" not in ctx.attr.urls[0] and
+ (len(ctx.attr.urls) < 2 and
+ ctx.attr.name not in _SINGLE_URL_WHITELIST)):
+ fail("tf_http_archive(urls) must have redundant URLs. The " +
+ "mirror.bazel.build URL must be present and it must come first. " +
+ "Even if you don't have permission to mirror the file, please " +
+ "put the correctly formatted mirror URL there anyway, because " +
+ "someone will come along shortly thereafter and mirror the file.")
+
+ use_syslib = _use_system_lib(ctx, ctx.attr.name)
+
+ # Use "BUILD.bazel" to avoid conflict with third party projects that contain a
+ # file or directory called "BUILD"
+ buildfile_path = ctx.path("BUILD.bazel")
+
+ if use_syslib:
+ if ctx.attr.system_build_file == None:
+ fail("Bazel was configured with TF_SYSTEM_LIBS to use a system " +
+ "library for %s, but no system build file for %s was configured. " +
+ "Please add a system_build_file attribute to the repository rule" +
+ "for %s." % (ctx.attr.name, ctx.attr.name, ctx.attr.name))
+ ctx.symlink(Label(ctx.attr.system_build_file), buildfile_path)
+
+ else:
+ ctx.download_and_extract(
+ ctx.attr.urls,
+ "",
+ ctx.attr.sha256,
+ ctx.attr.type,
+ ctx.attr.strip_prefix,
+ )
+ if ctx.attr.delete:
+ _apply_delete(ctx, ctx.attr.delete)
+ if ctx.attr.patch_file != None:
+ _apply_patch(ctx, ctx.attr.patch_file)
+ ctx.symlink(Label(ctx.attr.build_file), buildfile_path)
+
+ for internal_src, external_dest in ctx.attr.link_files.items():
+ ctx.symlink(Label(internal_src), ctx.path(external_dest))
+
+# Downloads and creates Bazel repos for dependencies.
+#
+# This is an upgrade for tf_http_archive that works with go/tfbr-thirdparty.
+#
+# For link_files, specify each dict entry as:
+# "//path/to/source:file": "localfile"
+third_party_http_archive = repository_rule(
+ implementation = _third_party_http_archive,
+ attrs = {
+ "sha256": attr.string(mandatory = True),
+ "urls": attr.string_list(mandatory = True, allow_empty = False),
+ "strip_prefix": attr.string(),
+ "type": attr.string(),
+ "delete": attr.string_list(),
+ "build_file": attr.string(mandatory = True),
+ "system_build_file": attr.string(mandatory = False),
+ "patch_file": attr.label(),
+ "link_files": attr.string_dict(),
+ },
+ environ = [
+ "TF_SYSTEM_LIBS",
+ ],
+)
diff --git a/third_party/systemlibs/nsync.BUILD b/third_party/systemlibs/nsync.BUILD
new file mode 100644
index 0000000000..c5d4ad0a76
--- /dev/null
+++ b/third_party/systemlibs/nsync.BUILD
@@ -0,0 +1,23 @@
+licenses(["notice"]) # BSD 3-Clause
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nsync_headers",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nsync",
+ linkopts = ["-lnsync"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nsync_cpp",
+ linkopts = ["-lnsync_cpp"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl
index 07a44c317e..8b09c9ac1f 100644
--- a/third_party/systemlibs/syslibs_configure.bzl
+++ b/third_party/systemlibs/syslibs_configure.bzl
@@ -7,9 +7,9 @@
the system version instead
"""
-_TF_SYSTEM_LIBS="TF_SYSTEM_LIBS"
+_TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS"
-VALID_LIBS=[
+VALID_LIBS = [
"astor_archive",
"com_googlesource_code_re2",
"curl",
@@ -22,6 +22,7 @@ VALID_LIBS=[
"jsoncpp_git",
"lmdb",
"nasm",
+ "nsync",
"org_sqlite",
"pcre",
"png_archive",
@@ -32,112 +33,109 @@ VALID_LIBS=[
"zlib_archive",
]
-
def auto_configure_fail(msg):
- """Output failure message when syslibs configuration fails."""
- red = "\033[0;31m"
- no_color = "\033[0m"
- fail("\n%sSystem Library Configuration Error:%s %s\n" % (red, no_color, msg))
-
+ """Output failure message when syslibs configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sSystem Library Configuration Error:%s %s\n" % (red, no_color, msg))
def _is_windows(repository_ctx):
- """Returns true if the host operating system is windows."""
- os_name = repository_ctx.os.name.lower()
- if os_name.find("windows") != -1:
- return True
- return False
-
+ """Returns true if the host operating system is windows."""
+ os_name = repository_ctx.os.name.lower()
+ if os_name.find("windows") != -1:
+ return True
+ return False
def _enable_syslibs(repository_ctx):
- s = repository_ctx.os.environ.get(_TF_SYSTEM_LIBS, '').strip()
- if not _is_windows(repository_ctx) and s != None and s != '':
- return True
- return False
-
+ s = repository_ctx.os.environ.get(_TF_SYSTEM_LIBS, "").strip()
+ if not _is_windows(repository_ctx) and s != None and s != "":
+ return True
+ return False
def _get_system_lib_list(repository_ctx):
- """Gets the list of deps that should use the system lib.
+ """Gets the list of deps that should use the system lib.
- Args:
- repository_ctx: The repository context.
+ Args:
+ repository_ctx: The repository context.
- Returns:
- A string version of a python list
- """
- if _TF_SYSTEM_LIBS not in repository_ctx.os.environ:
- return []
+ Returns:
+ A string version of a python list
+ """
+ if _TF_SYSTEM_LIBS not in repository_ctx.os.environ:
+ return []
- libenv = repository_ctx.os.environ[_TF_SYSTEM_LIBS].strip()
- libs = []
+ libenv = repository_ctx.os.environ[_TF_SYSTEM_LIBS].strip()
+ libs = []
- for lib in list(libenv.split(',')):
- lib = lib.strip()
- if lib == "":
- continue
- if lib not in VALID_LIBS:
- auto_configure_fail("Invalid system lib set: %s" % lib)
- return []
- libs.append(lib)
-
- return libs
+ for lib in list(libenv.split(",")):
+ lib = lib.strip()
+ if lib == "":
+ continue
+ if lib not in VALID_LIBS:
+ auto_configure_fail("Invalid system lib set: %s" % lib)
+ return []
+ libs.append(lib)
+ return libs
def _format_system_lib_list(repository_ctx):
- """Formats the list of deps that should use the system lib.
-
- Args:
- repository_ctx: The repository context.
-
- Returns:
- A list of the names of deps that should use the system lib.
- """
- libs = _get_system_lib_list(repository_ctx)
- ret = ''
- for lib in libs:
- ret += "'%s',\n" % lib
-
- return ret
-
-
-def _tpl(repository_ctx, tpl, substitutions={}, out=None):
- if not out:
- out = tpl.replace(":", "")
- repository_ctx.template(
- out,
- Label("//third_party/systemlibs%s.tpl" % tpl),
- substitutions,
- False)
-
+ """Formats the list of deps that should use the system lib.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A list of the names of deps that should use the system lib.
+ """
+ libs = _get_system_lib_list(repository_ctx)
+ ret = ""
+ for lib in libs:
+ ret += "'%s',\n" % lib
+
+ return ret
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl.replace(":", "")
+ repository_ctx.template(
+ out,
+ Label("//third_party/systemlibs%s.tpl" % tpl),
+ substitutions,
+ False,
+ )
def _create_dummy_repository(repository_ctx):
- """Creates the dummy repository to build with all bundled libraries."""
-
- _tpl(repository_ctx, ":BUILD")
- _tpl(repository_ctx, ":build_defs.bzl",
- {
- "%{syslibs_enabled}": 'False',
- "%{syslibs_list}": '',
- })
-
+ """Creates the dummy repository to build with all bundled libraries."""
+
+ _tpl(repository_ctx, ":BUILD")
+ _tpl(
+ repository_ctx,
+ ":build_defs.bzl",
+ {
+ "%{syslibs_enabled}": "False",
+ "%{syslibs_list}": "",
+ },
+ )
def _create_local_repository(repository_ctx):
- """Creates the repository to build with system libraries."""
-
- _tpl(repository_ctx, ":BUILD")
- _tpl(repository_ctx, ":build_defs.bzl",
- {
- "%{syslibs_enabled}": 'True',
- "%{syslibs_list}": _format_system_lib_list(repository_ctx),
- })
-
+ """Creates the repository to build with system libraries."""
+
+ _tpl(repository_ctx, ":BUILD")
+ _tpl(
+ repository_ctx,
+ ":build_defs.bzl",
+ {
+ "%{syslibs_enabled}": "True",
+ "%{syslibs_list}": _format_system_lib_list(repository_ctx),
+ },
+ )
def _syslibs_autoconf_impl(repository_ctx):
- """Implementation of the syslibs_configure repository rule."""
- if not _enable_syslibs(repository_ctx):
- _create_dummy_repository(repository_ctx)
- else:
- _create_local_repository(repository_ctx)
-
+ """Implementation of the syslibs_configure repository rule."""
+ if not _enable_syslibs(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ else:
+ _create_local_repository(repository_ctx)
syslibs_configure = repository_rule(
implementation = _syslibs_autoconf_impl,