Diffstat (file mode, path, lines changed):
-rw-r--r--  README.md  2
-rw-r--r--  tensorflow/compiler/aot/BUILD  5
-rw-r--r--  tensorflow/compiler/aot/codegen.cc  19
-rw-r--r--  tensorflow/compiler/aot/codegen_test.cc  6
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.cc  9
-rw-r--r--  tensorflow/compiler/aot/tests/BUILD  1
-rw-r--r--  tensorflow/compiler/aot/tests/tfcompile_test.cc  4
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc  7
-rw-r--r--  tensorflow/compiler/jit/BUILD  55
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.cc  3
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis.cc  5
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis_test.cc  1
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc  5
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc  10
-rw-r--r--  tensorflow/compiler/jit/kernels/BUILD  13
-rw-r--r--  tensorflow/compiler/jit/kernels/parallel_check_op.cc  144
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc  2
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc  137
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc  178
-rw-r--r--  tensorflow/compiler/jit/ops/BUILD  7
-rw-r--r--  tensorflow/compiler/jit/ops/parallel_check_op.cc  30
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc  1
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.cc  336
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.h  73
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc  540
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc  24
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h  7
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util_test.cc  1
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc  15
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.h  6
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc  2
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc  17
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer.cc  2
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer_test.cc  25
-rw-r--r--  tensorflow/compiler/tests/BUILD  16
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py  35
-rw-r--r--  tensorflow/compiler/tests/image_ops_test.py  26
-rw-r--r--  tensorflow/compiler/tests/xla_ops_test.py  301
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD  49
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc  25
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.h  14
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis_test.cc  19
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond.cc  45
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow_util.h  11
-rw-r--r--  tensorflow/compiler/tf2xla/graph_compiler.cc  3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD  9
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/bcast_ops.cc  9
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc  101
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc  161
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pooling_ops.cc  144
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc  92
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc  3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/softmax_op.cc  4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc  115
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc  101
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc  65
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc  105
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc  102
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc  147
-rw-r--r--  tensorflow/compiler/tf2xla/lib/BUILD  7
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc  12
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.h  6
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.cc  17
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.h  6
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.cc  51
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.h  7
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.cc  46
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.h  6
-rw-r--r--  tensorflow/compiler/tf2xla/ops/BUILD  2
-rw-r--r--  tensorflow/compiler/tf2xla/ops/xla_ops.cc  192
-rw-r--r--  tensorflow/compiler/tf2xla/python/BUILD  1
-rw-r--r--  tensorflow/compiler/tf2xla/python/xla.py  336
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table.cc  130
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table.h  71
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table_test.cc  66
-rw-r--r--  tensorflow/compiler/tf2xla/sharding_util.cc  6
-rw-r--r--  tensorflow/compiler/tf2xla/str_util.cc  44
-rw-r--r--  tensorflow/compiler/tf2xla/str_util.h  42
-rw-r--r--  tensorflow/compiler/tf2xla/str_util_test.cc  60
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc  6
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc  8
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util_test.cc  6
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc  31
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc  37
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h  4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.cc  11
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.h  3
-rw-r--r--  tensorflow/compiler/xla/BUILD  19
-rw-r--r--  tensorflow/compiler/xla/array.h  10
-rw-r--r--  tensorflow/compiler/xla/array2d.h  3
-rw-r--r--  tensorflow/compiler/xla/array4d.h  3
-rw-r--r--  tensorflow/compiler/xla/client/BUILD  3
-rw-r--r--  tensorflow/compiler/xla/client/client.cc  4
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.cc  2
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.h  2
-rw-r--r--  tensorflow/compiler/xla/client/executable_build_options.cc  8
-rw-r--r--  tensorflow/compiler/xla/client/executable_build_options.h  10
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD  3
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.cc  4
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.cc  5
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc  19
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h  2
-rw-r--r--  tensorflow/compiler/xla/device_util.h  6
-rw-r--r--  tensorflow/compiler/xla/index_util.cc  4
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc  12
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc  4
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h  29
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc  1
-rw-r--r--  tensorflow/compiler/xla/literal.cc  20
-rw-r--r--  tensorflow/compiler/xla/literal.h  2
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.cc  195
-rw-r--r--  tensorflow/compiler/xla/literal_test.cc  23
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc  13
-rw-r--r--  tensorflow/compiler/xla/literal_util.h  4
-rw-r--r--  tensorflow/compiler/xla/metric_table_report.cc  23
-rw-r--r--  tensorflow/compiler/xla/metric_table_report.h  5
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.cc  4
-rw-r--r--  tensorflow/compiler/xla/python/BUILD  1
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i  3
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/BUILD  63
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.h  2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc  42
-rw-r--r--  tensorflow/compiler/xla/service/allocation_tracker.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/backend.h  4
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.h  2
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.h  2
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander_test.cc  1
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding.h  2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.cc  28
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.h  4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.h  4
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc  16
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc  42
-rw-r--r--  tensorflow/compiler/xla/service/buffer_value.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner.h  2
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner_test.cc  1
-rw-r--r--  tensorflow/compiler/xla/service/channel_tracker.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/computation_layout.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier.h  6
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.h  4
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc  35
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD  9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc  72
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.h  11
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.cc  22
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc  19
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h  13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc  17
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc  13
-rw-r--r--  tensorflow/compiler/xla/service/defuser.h  2
-rw-r--r--  tensorflow/compiler/xla/service/defuser_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.h  2
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h  2
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h  2
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.h  2
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc  17
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD  13
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_comparator.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/kernel_thunk.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD  1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc  31
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc  17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h  6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc  35
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc  76
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc  47
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc  41
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk_schedule.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis.cc  18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_buffer.cc  16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc  37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.h  3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc  48
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_isolator.cc  26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_isolator.h  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_metadata.h  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_remover.h  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_test.cc  44
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.h  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc  12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_execution_profile_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc  56
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  111
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h  24
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc  85
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.cc  38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.h  12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_liveness_analysis.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.cc  12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.h  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_dce.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc  120
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.h  16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc  28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_interface.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util_test.cc  1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc  59
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h  3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc  22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc  147
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.h  38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_subcomputation_unification.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc  25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_value.cc  18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc  153
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h  60
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/human_readable_profile_builder.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/human_readable_profile_builder.h  9
-rw-r--r--  tensorflow/compiler/xla/service/implicit_broadcast_remover.h  2
-rw-r--r--  tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc  49
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h  2
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis_test.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/inliner.h  2
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc  36
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.h  2
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD  8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h  2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h  3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.h  6
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h  42
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc  37
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h  30
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc  57
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.h  29
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h  4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.h  4
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/logical_buffer.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.h  6
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.cc  15
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.h  4
-rw-r--r--  tensorflow/compiler/xla/service/pattern_matcher.h  8
-rw-r--r--  tensorflow/compiler/xla/service/platform_util.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.h  2
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc  13
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.h  2
-rw-r--r--  tensorflow/compiler/xla/service/service.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc  122
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.h  2
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.cc  40
-rw-r--r--  tensorflow/compiler/xla/service/tuple_simplifier.h  2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_constant_sinking.h  2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h  2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.cc  15
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.h  4
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier_test.cc  20
-rw-r--r--  tensorflow/compiler/xla/service/while_util.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h  2
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc  60
-rw-r--r--  tensorflow/compiler/xla/shape_util.h  2
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/status_macros.cc  24
-rw-r--r--  tensorflow/compiler/xla/test_helpers.h  2
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD  38
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h  4
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc  17
-rw-r--r--  tensorflow/compiler/xla/tests/floor_ceil_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc  22
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h  16
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.cc  12
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.h  19
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc  50
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_hlo_test.cc  11
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_precision_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc  68
-rw-r--r--  tensorflow/compiler/xla/tests/reverse_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/test_macros.cc  13
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc  8
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h  4
-rw-r--r--  tensorflow/compiler/xla/tests/token_hlo_test.cc  18
-rw-r--r--  tensorflow/compiler/xla/tests/unary_op_test.cc  19
-rw-r--r--  tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc  29
-rw-r--r--  tensorflow/compiler/xla/tests/xla_internal_test_main.cc  14
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.cc  72
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.h  5
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.cc  17
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.h  5
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc  7
-rw-r--r--  tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc  4
-rw-r--r--  tensorflow/compiler/xla/util.cc  52
-rw-r--r--  tensorflow/compiler/xla/util.h  32
-rw-r--r--  tensorflow/compiler/xla/window_util.cc  11
-rw-r--r--  tensorflow/contrib/BUILD  1
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow.py  4
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow_test.py  8
-rw-r--r--  tensorflow/contrib/autograph/pyct/testing/BUILD  2
-rw-r--r--  tensorflow/contrib/data/__init__.py  4
-rw-r--r--  tensorflow/contrib/data/kernels/BUILD  12
-rw-r--r--  tensorflow/contrib/data/kernels/lmdb_dataset_op.cc  215
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc  9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD  28
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py  66
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/BUILD  13
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py  42
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py  50
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD  1
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py  59
-rw-r--r--  tensorflow/contrib/data/python/ops/parsing_ops.py  2
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py  52
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD  61
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py  8
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py  109
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops_test.py  28
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py  6
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  31
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_test.py  10
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_test_base.py  107
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy.py  44
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy_test.py  133
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py  19
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  27
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py  96
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py  56
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py  19
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py  3
-rw-r--r--  tensorflow/contrib/eager/python/BUILD  14
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50.py  4
-rw-r--r--  tensorflow/contrib/eager/python/remote.py  73
-rw-r--r--  tensorflow/contrib/eager/python/remote_test.py  13
-rw-r--r--  tensorflow/contrib/eager/python/tfe.py  3
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl  1
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h  5
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc  8
-rw-r--r--  tensorflow/contrib/lite/g3doc/apis.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/custom_operators.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/demo_android.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/demo_ios.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/devguide.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/ios.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/models.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/ops_versioning.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/overview.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/performance.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/rpi.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md  15
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/android_build.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/index.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md  2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md  2
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD  15
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h  175
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  239
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h  9
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc  76
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_test.cc  49
-rw-r--r--  tensorflow/contrib/lite/kernels/optional_tensor_test.cc  24
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc  2
-rw-r--r--  tensorflow/contrib/lite/kernels/unpack.cc  130
-rw-r--r--  tensorflow/contrib/lite/kernels/unpack_test.cc  225
-rwxr-xr-x  tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh  7
-rw-r--r--  tensorflow/contrib/lite/model.cc  10
-rw-r--r--  tensorflow/contrib/lite/models/speech_test.cc  10
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py  32
-rw-r--r--  tensorflow/contrib/lite/testing/tflite_driver.cc  22
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc  17
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  9
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  29
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc  21
-rw-r--r--  tensorflow/contrib/lite/toco/model.h  15
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc  20
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc  10
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc  1
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/BUILD  11
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights.cc  280
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights.h  38
-rw-r--r--  tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc  130
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py  51
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  54
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc  6
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/input_data.h  3
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD  1
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc  1
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py  15
-rw-r--r--  tensorflow/core/BUILD  1
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.cc  7
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc  9
-rw-r--r--  tensorflow/core/common_runtime/gpu_device_context.h  3
-rw-r--r--  tensorflow/core/framework/device_base.h  9
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.cc  2
-rw-r--r--  tensorflow/core/grappler/costs/analytical_cost_estimator.cc  1
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory.cc  1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc  6
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc  42
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc  1
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc  1
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc  11
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler_test.cc  1
-rw-r--r--  tensorflow/core/grappler/grappler_item_builder.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD  5
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  33
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_vectorization.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/evaluation_utils.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/evaluation_utils.h  1
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.cc  19
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc  3
-rw-r--r--  tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc  1
-rw-r--r--  tensorflow/core/grappler/optimizers/shape_optimizer.cc  1
-rw-r--r--  tensorflow/core/grappler/utils/functions.cc  14
-rw-r--r--  tensorflow/core/grappler/utils/functions.h  6
-rw-r--r--  tensorflow/core/grappler/utils/functions_test.cc  31
-rw-r--r--  tensorflow/core/kernels/data/BUILD  1
-rw-r--r--  tensorflow/core/kernels/data/parse_example_dataset_op.cc  8
-rw-r--r--  tensorflow/core/kernels/parameterized_truncated_normal_op.cc  3
-rw-r--r--  tensorflow/core/lib/monitoring/collection_registry.cc  8
-rw-r--r--  tensorflow/core/lib/monitoring/collection_registry.h  4
-rw-r--r--  tensorflow/core/lib/monitoring/metric_def.h  4
-rw-r--r--  tensorflow/core/lib/strings/numbers.h  4
-rw-r--r--  tensorflow/core/lib/strings/str_util.cc  5
-rw-r--r--  tensorflow/core/lib/strings/str_util.h  2
-rw-r--r--  tensorflow/core/platform/env.cc  4
-rw-r--r--  tensorflow/core/platform/file_system.cc  2
-rw-r--r--  tensorflow/core/platform/file_system_helper.cc  2
-rw-r--r--  tensorflow/core/platform/file_system_test.cc  2
-rw-r--r--  tensorflow/core/util/command_line_flags.cc  2
-rw-r--r--  tensorflow/core/util/ctc/ctc_beam_search.h  18
-rw-r--r--  tensorflow/core/util/env_var.cc  8
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.cc  2
-rw-r--r--  tensorflow/docs_src/guide/premade_estimators.md  2
-rw-r--r--  tensorflow/docs_src/guide/saved_model.md  2
-rw-r--r--  tensorflow/js/ops/ts_op_gen.cc  93
-rw-r--r--  tensorflow/js/ops/ts_op_gen_test.cc  138
-rw-r--r--  tensorflow/python/BUILD  5
-rw-r--r--  tensorflow/python/compat/compat.py  2
-rw-r--r--  tensorflow/python/debug/BUILD  2
-rw-r--r--  tensorflow/python/eager/BUILD  1
-rw-r--r--  tensorflow/python/eager/function.py  34
-rw-r--r--  tensorflow/python/eager/function_test.py  41
-rw-r--r--  tensorflow/python/estimator/estimator.py  6
-rw-r--r--  tensorflow/python/estimator/export/export.py  22
-rw-r--r--  tensorflow/python/estimator/export/export_test.py  7
-rw-r--r--  tensorflow/python/estimator/run_config.py  11
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py  74
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2.py  550
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2_test.py  764
-rw-r--r--  tensorflow/python/framework/function.py  23
-rw-r--r--  tensorflow/python/framework/function_def_to_graph.py  20
-rw-r--r--  tensorflow/python/framework/function_def_to_graph_test.py  43
-rw-r--r--  tensorflow/python/framework/ops.py  27
-rw-r--r--  tensorflow/python/framework/smart_cond.py  6
-rw-r--r--  tensorflow/python/framework/subscribe.py  7
-rw-r--r--  tensorflow/python/keras/backend_test.py  4
-rw-r--r--  tensorflow/python/keras/engine/sequential.py  1
-rw-r--r--  tensorflow/python/keras/layers/recurrent.py  53
-rw-r--r--  tensorflow/python/keras/layers/recurrent_test.py  18
-rw-r--r--  tensorflow/python/keras/models.py  2
-rw-r--r--  tensorflow/python/kernel_tests/cond_v2_test.py  31
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  1
-rw-r--r--  tensorflow/python/kernel_tests/ctc_decoder_ops_test.py  18
-rw-r--r--  tensorflow/python/kernel_tests/matmul_op_test.py  4
-rw-r--r--  tensorflow/python/kernel_tests/partitioned_variables_test.py  80
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py  10
-rw-r--r--  tensorflow/python/ops/cond_v2.py  2
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py  148
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py  10
-rw-r--r--  tensorflow/python/ops/math_ops.py  55
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py  9
-rw-r--r--  tensorflow/python/ops/sparse_ops.py  54
-rw-r--r--  tensorflow/python/ops/sparse_ops_test.py  32
-rw-r--r--  tensorflow/python/ops/variables.py  7
-rw-r--r--  tensorflow/python/training/checkpoint_management.py  14
-rw-r--r--  tensorflow/python/training/checkpoint_management_test.py  44
-rw-r--r--  tensorflow/python/training/checkpointable/util_test.py  2
-rw-r--r--  tensorflow/python/training/moving_averages.py  55
-rw-r--r--  tensorflow/python/training/moving_averages_test.py  21
-rw-r--r--  tensorflow/python/training/saver_test.py  2
-rw-r--r--  tensorflow/python/util/tf_export.py  13
-rw-r--r--  tensorflow/stream_executor/lib/env.h  2
-rw-r--r--  tensorflow/stream_executor/lib/path.cc  2
-rw-r--r--  tensorflow/stream_executor/lib/str_util.h  2
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt  4
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt  4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt  2
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt  4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt  4
-rw-r--r--  tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh  2
-rw-r--r--  tensorflow/tools/docker/README.md  7
-rw-r--r--  tensorflow/tools/dockerfiles/README.md  67
-rw-r--r--  tensorflow/tools/dockerfiles/assembler.Dockerfile  30
-rw-r--r--  tensorflow/tools/dockerfiles/assembler.py  554
-rw-r--r--  tensorflow/tools/dockerfiles/bashrc  50
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile  100
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile  89
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile  69
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile  58
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile  120
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile  109
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile  90
-rw-r--r--  tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile  79
-rw-r--r--  tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile  13
-rw-r--r--  tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile  8
-rw-r--r--  tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile  43
-rw-r--r--  tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile  23
-rw-r--r--  tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile  12
-rw-r--r--  tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile  2
-rw-r--r--  tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile  2
-rw-r--r--  tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile  24
-rw-r--r--  tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile  2
-rw-r--r--  tensorflow/tools/dockerfiles/spec.yml  195
-rw-r--r--  tensorflow/workspace.bzl  8
616 files changed, 12251 insertions(+), 4605 deletions(-)
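
Most of the per-file hunks that follow share one theme: TensorFlow-internal string helpers (tensorflow::str_util and StringPiece) are replaced by their Abseil equivalents (absl::StrJoin, absl::StrSplit, absl::StrContains, absl::EndsWith, absl::string_view). As a reader's aid, here is a minimal, self-contained C++ sketch of the mappings exercised below; the example strings are illustrative, not taken from this diff:

    #include <string>
    #include <vector>

    #include "absl/strings/match.h"      // absl::StrContains, absl::EndsWith
    #include "absl/strings/str_join.h"   // absl::StrJoin
    #include "absl/strings/str_split.h"  // absl::StrSplit

    int main() {
      std::vector<std::string> nodes = {"add", "mul", "sub"};

      // str_util::Join(nodes, ",")     ->  absl::StrJoin(nodes, ",")
      std::string joined = absl::StrJoin(nodes, ",");  // "add,mul,sub"

      // str_util::Split(s, '\n')       ->  absl::StrSplit(s, '\n')
      std::vector<std::string> lines = absl::StrSplit("a\nb", '\n');

      // str_util::StrContains(s, sub)  ->  absl::StrContains(s, sub)
      bool has_mul = absl::StrContains(joined, "mul");

      // str_util::EndsWith(s, suffix)  ->  absl::EndsWith(s, suffix)
      bool is_text = absl::EndsWith("graph.pbtxt", ".pbtxt");

      return (has_mul && is_text && lines.size() == 2) ? 0 : 1;
    }
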
diff --git a/README.md b/README.md
index 16d354ca7b..823c688096 100644
--- a/README.md
+++ b/README.md
@@ -100,7 +100,7 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) |
-| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)<br>[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) |
+| **Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)<br>[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) |
## For more information
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 2220d0786d..59b961cdd9 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -32,7 +32,6 @@ cc_library(
deps = [
":embedded_protocol_buffers",
"//tensorflow/compiler/tf2xla",
- "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -56,6 +55,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -72,6 +72,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
],
@@ -100,6 +101,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -195,6 +197,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 44291d977f..e77a8fecf0 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,9 +20,10 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
}
rewrites->push_back({"{{I}}", strings::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type});
- rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
+ rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
return Status::OK();
@@ -158,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) {
- str_util::ReplaceAllPairs(&code, rewrites);
- return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
+ absl::StrReplaceAll(rewrites, &code);
+ absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
+ return code;
}
// Generate methods for args (inputs).
@@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
- {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
+ {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
+ absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
@@ -595,8 +596,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}",
- str_util::Join(buffer_infos_as_strings, ",\n")}};
- str_util::ReplaceAllPairs(header, rewrites);
+ absl::StrJoin(buffer_infos_as_strings, ",\n")}};
+ absl::StrReplaceAll(rewrites, header);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 60d59ae996..e3a53edb73 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/match.h"
#include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -34,9 +34,9 @@ namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 8fb2fad31c..1401aae758 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
+#include "absl/strings/str_replace.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
" return proto;\n"
" }()";
- str_util::ReplaceAllPairs(
- &code,
+ return absl::StrReplaceAll(
+ code,
{
{"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)},
{"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)},
{"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)},
});
- return code;
}
static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
@@ -97,7 +96,7 @@ static StatusOr<std::unique_ptr<llvm::TargetMachine>>
GetTargetMachineFromTriple(StringPiece target_triple) {
std::string error;
std::string normalized_triple =
- llvm::Triple::normalize(AsStringRef(target_triple));
+ llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(normalized_triple, error);
if (target == nullptr) {
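
The codegen.cc and embedded_protocol_buffers.cc hunks above exercise both overloads of absl::StrReplaceAll from absl/strings/str_replace.h: one rewrites a std::string in place through a pointer (returning the substitution count), the other takes the input by view and returns a rewritten copy, which is why CreateCPPShimExpression can now simply return the call. A small sketch, with made-up template text rather than the real generated code:

    #include <string>

    #include "absl/strings/str_replace.h"

    int main() {
      // In-place overload: mutates *target, returns the number of replacements.
      std::string code = "{{TYPE}} {{NAME}};";
      absl::StrReplaceAll({{"{{TYPE}}", "float"}}, &code);  // "float {{NAME}};"

      // Copying overload: leaves its input alone and returns the result,
      // so it can be used directly in a return statement.
      std::string decl = absl::StrReplaceAll(code, {{"{{NAME}}", "arg0"}});
      return decl == "float arg0;" ? 0 : 1;
    }
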
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 0ecc3feeb6..7364d63b53 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -226,5 +226,6 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 0c0c676ece..dd2b151098 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
+#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) {
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
- tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+ absl::StrSplit(hlo_profile_as_string, '\n');
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 839e1588b7..f3c44e9dda 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@@ -34,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,7 +56,7 @@ const char kUsageHeader[] =
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
- if (str_util::EndsWith(fname, ".pbtxt")) {
+ if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
@@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
- std::cout << str_util::Join(nodes, ",");
+ std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 2466c218c8..df81f3c23e 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -311,6 +311,51 @@ tf_cc_test(
)
cc_library(
+ name = "resource_operation_safety_analysis",
+ srcs = ["resource_operation_safety_analysis.cc"],
+ hdrs = ["resource_operation_safety_analysis.h"],
+ deps = [
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_safety_analysis_test",
+ srcs = ["resource_operation_safety_analysis_test.cc"],
+ deps = [
+ ":common",
+ ":resource_operation_safety_analysis",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
name = "compilation_passes",
srcs = [
"build_xla_launch_ops_pass.cc",
@@ -335,11 +380,10 @@ cc_library(
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
- "//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
- "//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -351,6 +395,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/strings",
],
)
@@ -359,6 +404,7 @@ cc_library(
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
+ ":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
@@ -437,6 +483,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
@@ -448,6 +495,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -528,6 +576,9 @@ tf_cuda_cc_test(
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 1b1ce78ed2..a7f8a5613c 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -126,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 0ca0f949dc..fe28502f69 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
@@ -153,7 +154,7 @@ class AndPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
+ return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
}
Kind kind() const override { return Kind::kAnd; }
@@ -182,7 +183,7 @@ class OrPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
+ return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
}
Kind kind() const override { return Kind::kOr; }
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index cc9f102398..28a56044d5 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f150bf1819..2788102620 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -2504,7 +2504,8 @@ Status EncapsulateSubgraphsPass::Run(
const int num_args = input_permutation->size();
std::vector<bool> const_args(num_args);
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr));
DataTypeVector arg_types(num_args);
TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index c0543a0079..b3600fc48b 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
std::unordered_set<string> control_input_a;
std::unordered_set<string> control_input_b;
for (int i = 0; i < a.input_size(); ++i) {
- if (str_util::StartsWith(a.input(i), "^")) {
- if (!str_util::StartsWith(b.input(i), "^")) {
+ if (absl::StartsWith(a.input(i), "^")) {
+ if (!absl::StartsWith(b.input(i), "^")) {
if (diff) {
*diff = strings::StrCat(
diff_preamble, " mismatch for node ", a.name(), " input ", i,
@@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
@@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 8f78c110cb..253a5d2547 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -29,16 +29,3 @@ cc_library(
],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- visibility = ["//tensorflow/compiler/jit:friends"],
- deps = [
- "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- ],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc
deleted file mode 100644
index bd4eefbc0b..0000000000
--- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-namespace {
-
-// Inputs 2*N tensors, outputs the first N inputs.
-// Logs errors if input tensor i and i + N are not (near) identical
-// in any position.
-class ParallelCheckOp : public OpKernel {
- public:
- explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- template <typename T>
- int CompareTensors(DataType dtype, const char* v0, const char* v1,
- int64 num_elts, int input_idx) {
- int failed = 0;
- const T* p0 = reinterpret_cast<const T*>(v0);
- const T* p1 = reinterpret_cast<const T*>(v1);
- double rtol;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
- &rtol)) {
- LOG(ERROR) << "can't convert parallel_check_rtol "
- << flags->parallel_check_rtol << " to double";
- }
- double atol;
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
- &atol)) {
- LOG(ERROR) << "can't convert parallel_check_atol "
- << flags->parallel_check_atol << " to double";
- }
- for (int i = 0; i < num_elts; ++i) {
- bool ok = (p0[i] == p1[i]);
- VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
- if (!ok) {
- if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
- float tolerance =
- std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
- T diff = p0[i] - p1[i];
- if (diff < 0) diff = 0 - diff;
- ok = (diff <= tolerance);
- }
- if (ok) continue;
- LOG(ERROR) << "Op " << name() << " fails equality at output "
- << input_idx << " type " << DataTypeString(dtype)
- << " element " << i << ": std_val=" << p0[i]
- << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
- if (++failed > 10) break;
- }
- }
- return failed;
- }
-
- void Compute(OpKernelContext* ctx) override {
- VLOG(1) << "Compute " << name();
- const int num_pairs = ctx->num_inputs() / 2;
- for (int i = 0; i < num_pairs; ++i) {
- CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
- Tensor t0 = ctx->input(i);
- Tensor t1 = ctx->input(i + num_pairs);
- int64 num_elts = t0.NumElements();
- CHECK_EQ(num_elts, t1.NumElements());
-
- // Compare inputs elementwise for near-exact equality.
- const char* v0 = t0.tensor_data().data();
- const char* v1 = t1.tensor_data().data();
- int failed = 0;
- switch (ctx->input_dtype(i)) {
- case DT_INT32:
- failed =
- CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_INT64:
- failed =
- CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_FLOAT:
- failed =
- CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_DOUBLE:
- failed =
- CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_BOOL:
- failed =
- CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- default:
- LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
- }
- if (failed > 0) {
- LOG(ERROR) << "check failed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (flags->parallel_check_failfast) {
- LOG(QFATAL) << "failfast on first parallel-check failure";
- }
- } else {
- VLOG(1) << "check passed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- }
-
- // Propagate the std value.
- if (IsRefType(ctx->input_dtype(i))) {
- ctx->forward_ref_input_to_ref_output(i, i);
- } else {
- ctx->set_output(i, ctx->input(i));
- }
- }
- }
-
- TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
-};
-
-REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
- ParallelCheckOp);
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index ddb27a38ae..fde4135bf7 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -187,7 +187,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, &compile_options));
+ &kernel, &executable, compile_options));
VLOG(1) << "Executing XLA Computation...";
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 11bd5eec23..518c39ec15 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -74,18 +77,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
+bool HasResourceOutput(const Node& node) {
+ return std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+bool HasResourceInput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end();
+}
+
+// Returns true if `node` is a resource operation recognized by tf2xla that
+// operates on something other than resource variables.
+bool IsNonResourceVarResourceOp(const Node& node) {
+ // TODO(b/112837194): We can't cluster these because we only support
+ // snapshotting resource variables (and we can't e.g. snapshot stacks). This
+ // limitation may be fixable with some work.
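+  // For example (assuming the tf2xla resource operation table registers
+  // them, a sketch for illustration): TensorArray and Stack operations fall
+  // in this category, while ReadVariableOp does not.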
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
+ return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+}
+
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
@@ -100,7 +125,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
@@ -115,7 +141,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
@@ -127,7 +154,8 @@ bool IsCompilableWhile(const Node& while_node,
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
@@ -167,12 +195,17 @@ bool IsCompilableCall(const NodeDef& call_def,
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
continue;
if (node->type_string() == "While") {
- // Handle functional While loop (not in open source build).
- return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
+ // Handle functional While loop.
+ return IsCompilableWhile(*node, jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime);
+ }
+ if (!allow_resource_ops &&
+ (HasResourceInput(*node) || HasResourceOutput(*node))) {
+ return false;
}
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, depth + 1,
- lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
<< node->name() << ": " << node->def().ShortDebugString();
return false;
@@ -343,6 +376,10 @@ Status FindCompilationCandidates(
flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+ std::vector<bool> compile_time_const_nodes(graph.num_node_ids(), false);
+ TF_RETURN_IF_ERROR(
+ BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
+ &compile_time_const_nodes));
int64& fuel =
legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
@@ -386,19 +423,46 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type,
+ registration->compile_resource_ops, 0, lib_runtime)) {
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
<< node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
- HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
+ // We don't have a way of returning values of type DT_RESOURCE from XLA
+ // computations so we avoid auto-clustering nodes producing DT_RESOURCE.
+ // XlaLaunchOp also cannot snapshot resources that are not resource
+      // variables, so we avoid clustering resource operations that operate on
+      // resources other than resource variables.
+ VLOG(2) << "Rejecting: " << node->name() << ": resource output "
<< node->type_string();
continue;
}
+ if (compile_time_const_nodes[node->id()] &&
+ !registration->requires_compilation) {
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(
+ OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def));
+ if (op_def->is_stateful()) {
+ // We need to be able to constant fold the nodes in
+ // compile_time_const_nodes given constant inputs (required by XLA) and
+ // therefore can't auto-cluster stateful ops since these can never be
+ // constant folded.
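+        // For example, a Reshape whose shape operand is produced by
+        // RandomUniformInt (a stateful op) must be rejected here; the
+        // RandomShape test added in this change exercises this case.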
+ VLOG(2) << "Rejecting " << node->name()
+ << ": must-be-constant stateful op";
+ continue;
+ }
+ }
+ // We don't auto-cluster functional control flow nodes containing resource
+ // operations because safety checks are trickier in this case.
+ // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
+ // for CPU/GPU.
if (node->type_string() == "While" &&
- !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
+ !IsCompilableWhile(*node, jit_device_type,
+ registration->compile_resource_ops, 0,
+ lib_runtime)) {
continue;
}
// _Arg nodes in a top-level function represent feeds.
@@ -457,7 +521,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
DeviceType jit_device_type(registration->compilation_device_name);
- return IsCompilableCall(ndef, jit_device_type, 0, flr);
+
+ // We can always *compile* resource operations, even if we are sometimes
+ // unable to auto-cluster them.
+ const bool compile_resource_ops = true;
+ return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr);
}
Status MarkForCompilationPass::Run(
@@ -609,6 +677,43 @@ static bool IsShapeConsumerOp(const Node& node) {
node.type_string() == "Size";
}
+static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
+ // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
+  // ignore it during resource operation safety analysis. We need this hack
+  // for two reasons:
+ //
+ // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
+ // 2. We don't support live-out values of type DT_RESOURCE and live-in values
+ // of type DT_RESOURCE that are not resource variables.
+ //
+ // Together these imply we cannot let resource variable safety analysis
+  // force the endpoints of e.g. a TensorArrayV3->TensorArrayAssignV3 edge into
+  // different clusters: both of them will have to be clustered because of (1),
+  // and we won't be able to keep the edge between the two since neither the
+  // input to the second XLA cluster nor the output from the first XLA cluster
+  // is supported because of (2).
+ //
+ // TODO(b/113100872): This can be fixed if the TensorFlow representation for
+ // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
+ // (2) would no longer hold.
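+  //
+  // Concretely (an illustration of the rule below): a resource operation
+  // whose assigned device is e.g. /device:XLA_GPU:0 is ignored by the safety
+  // analysis because that registration has compile_resource_ops set, while
+  // the same operation assigned to /device:GPU:0 is not ignored.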
+
+ if (n.assigned_device_name().empty()) {
+ *ignore = false;
+ return Status::OK();
+ }
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n.assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *ignore = true;
+ } else {
+ *ignore = registration->compile_resource_ops;
+ }
+ return Status::OK();
+}
+
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
@@ -637,6 +742,8 @@ Status MarkForCompilationPass::RunImpl(
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@@ -675,7 +782,7 @@ Status MarkForCompilationPass::RunImpl(
string to_scope;
for (int to : cycles.Successors(from)) {
if (to >= graph->num_node_ids()) {
- // Node is a "frame" node that is present only in the cycle detection
+ // Node is a fictitious node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9d7ac0d609..807ab51fd3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -26,11 +28,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -48,9 +50,35 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
ids[node->name()] = cluster;
}
}
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Clusters:";
+ for (const auto& p : ids) {
+ VLOG(2) << " " << p.first << " -> " << p.second;
+ }
+ }
return ids;
}
+gtl::FlatMap<string, std::vector<string>> GetClusterSets(
+ const Graph& g, std::vector<string>* cluster_names = nullptr) {
+ CHECK(cluster_names == nullptr || cluster_names->empty());
+ gtl::FlatMap<string, std::vector<string>> cluster_sets;
+ for (const auto& p : GetClusters(g)) {
+ cluster_sets[p.second].push_back(p.first);
+ }
+ for (auto& p : cluster_sets) {
+ if (cluster_names != nullptr) {
+ cluster_names->push_back(p.first);
+ }
+ std::sort(p.second.begin(), p.second.end());
+ }
+ if (cluster_names != nullptr) {
+ std::sort(cluster_names->begin(), cluster_names->end());
+ }
+ return cluster_sets;
+}
+
TEST(XlaCompilationTest, Chains) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
@@ -501,38 +529,104 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
EXPECT_EQ(clusters["B"], clusters["C"]);
}
-REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
-REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");
-
namespace {
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
-class DummyOp : public XlaOpKernel {
- using XlaOpKernel::XlaOpKernel;
- void Compile(XlaOpKernelContext* ctx) override {}
-};
-
-REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
-REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
+ var_handle, value_to_write);
+ return assign_op.operation.node();
+}
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
} // namespace
-TEST(XlaCompilationTest, Resources) {
+TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
+
+ FixupSourceAndSinkEdges(root.graph());
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- GraphDef graphdef;
- {
- GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
- Node* a =
- ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
- Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
- // We should not form clusters with resource ops by default.
- Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
- Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
- ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
- TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
- }
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- auto clusters = GetClusters(*graph);
- EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector<string> expected_clustered_nodes = {"AssignmentW",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::vector<string> cluster_names;
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph, &cluster_names);
+
+ ASSERT_EQ(cluster_sets.size(), 2);
+
+ std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
+ "ValueToAssignW0"};
+ ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
+
+ std::vector<string> expected_clustered_nodes_b = {
+ "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
+ ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
}
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
@@ -562,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.ToString(),
- "Edge from c to a would create a cycle.\n"
- "+-> a\n"
- "| b\n"
- "+-- c\n"));
+ EXPECT_TRUE(absl::StrContains(status.ToString(),
+ "Edge from c to a would create a cycle.\n"
+ "+-> a\n"
+ "| b\n"
+ "+-- c\n"));
}
TEST(XlaCompilationTest, Retval) {
@@ -731,5 +825,27 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
EXPECT_EQ(clusters, expected_clusters);
}
+TEST(XlaCompilationTest, RandomShape) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
+ ops::Const(root.WithOpName("minval"), 1),
+ ops::Const(root.WithOpName("maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["shape"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index c9e46bc147..13804c6a05 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -10,10 +10,3 @@ cc_library(
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- deps = ["//tensorflow/core:framework"],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/ops/parallel_check_op.cc b/tensorflow/compiler/jit/ops/parallel_check_op.cc
deleted file mode 100644
index db5c195578..0000000000
--- a/tensorflow/compiler/jit/ops/parallel_check_op.cc
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("ParallelCheck")
- .Attr("T: list(type) >= 0")
- .Input("expected: T")
- .Input("actual: T")
- .Output("result: T")
- .Doc(R"doc(
-Op that compares two sets of inputs for near-identity, and propagates the first.
-Inequality is logged to ERROR log.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 08a956e4c6..f61a955c22 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
new file mode 100644
index 0000000000..1ba4a5ef73
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -0,0 +1,336 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ALGORITHM OVERVIEW
+// ==================
+//
+// An XLA cluster hoists all resource reads to the beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// computes the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+// dependencies.
+//
+// Specifically, the result computed by this analysis contains the edge {W, R}
+// iff all of these hold true:
+//
+// - In the graph (g - {edges from NextIteration to Merge}) there is a path
+// from W to R.
+// - IsEdgeSafe(W, R) == False [defined below]
+// - W != R (note: some resource operations both read from and write to
+// resource variables).
+//
+// The result is incorrect around loops because we ignore edges from
+// NextIteration to Merge, but that should be fine because we don't cluster
+// these edges. For instance, in:
+//
+// Init -----> Merge <-------+
+// | |
+// v |
+// Read |
+// | |
+// v |
+// Write |
+// | |
+// v |
+// NextIteration --+
+//
+// we won't put (Read, Write) in the returned set. This is fine if
+// auto-clustering can only cluster the Read->Write edge, but it is a problem if
+// it clusters the Write->NextIteration->Merge->Read edges instead. The same
+// problem is present for the functional version of the loop above. We rely on
+// auto-clustering to not cluster control flow edges like NextIteration->Merge.
+// This is enough to avoid the explicit-control-flow problem shown above. One
+// way to think about this is that we only care about cases where two nodes, A
+// and B, would normally have been put in the same cluster but cannot legally be
+// in the same cluster because of resource-variable dependencies. If A and B
+// normally have been put in the same cluster then all paths between A and B
+// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo
+// there could not have been a NextIteration->Merge edge between A and B since
+// we don't cluster these edges.
+//
+// We also rely on auto-clustering to not cluster functional control flow nodes
+// that contain resource operations.
+//
+// IMPLEMENTATION
+// --------------
+//
+// We traverse the graph minus backedges in reverse post order, mapping each
+// node to the set of resource operations reaching that node. Since we visit
+// producers before consumers, we can construct the set of reaching operations
+// by taking the union of the operations reaching the input nodes. These
+// "reaching resource operations" can then be used to create the pairs of
+// incompatible nodes using `IsEdgeSafe`.
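+//
+// As a small worked example (a sketch, not exhaustive): for the chain
+// Write -> Neutral -> Read, the reaching sets computed in reverse post order
+// are {} at Write, {Write} at Neutral, and {Write} at Read, so the analysis
+// emits the pair (Write, Read) because IsEdgeSafe(kWrite, kRead) is false.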
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+// Sets `*out_result` to true if `n` may call a function.
+Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
+ bool* out_result) {
+ if (flib_def->Contains(n.type_string())) {
+ *out_result = true;
+ } else {
+ *out_result =
+ std::any_of(n.def().attr().begin(), n.def().attr().end(),
+ [](const std::pair<string, AttrValue>& name_attr_pair) {
+ return name_attr_pair.second.has_func();
+ });
+ }
+
+ return Status::OK();
+}
+
+// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
+// not a resource operation recognized by XLA then sets `out_resource_op_kind`
+// to nullopt.
+Status XlaResourceOpKindForNode(
+ const Node& n, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ absl::optional<XlaResourceOpKind>* out_resource_op_kind) {
+ bool should_ignore = false;
+ if (resource_ops_to_ignore) {
+ TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
+ }
+ if (should_ignore) {
+ *out_resource_op_kind = absl::nullopt;
+ return Status::OK();
+ }
+
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
+ if (op_info) {
+ *out_resource_op_kind = op_info->kind();
+ return Status::OK();
+ }
+
+ // We conservatively assume that functions will both read and write resource
+ // variables. In the future we may consider doing some form of
+ // inter-procedural analysis.
+ bool may_call_function;
+ TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
+ if (may_call_function) {
+ *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
+ } else {
+ *out_resource_op_kind = absl::nullopt;
+ }
+
+ return Status::OK();
+}
+
+// Returns true if a control or data dependence from a TensorFlow operation of
+// resource op kind `from` to a TensorFlow operation of resource op kind `to`
+// can be represented by an XLA cluster and needs no special handling around
+// auto-jit.
+bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
+  // An XLA cluster forces all reads to happen before all writes, which means the
+ // kinds of edges it can faithfully represent are: Read->Write, Read->Modify,
+ // Modify->Write, Read->Read, Write->Write.
+ //
+ // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+ // dependencies.
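+  //
+  // Concretely, under the current rule:
+  //   IsEdgeSafe(kRead, kWrite)      -> true (the only pair accepted today)
+  //   IsEdgeSafe(kWrite, kRead)      -> false
+  //   IsEdgeSafe(kReadWrite, kWrite) -> false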
+ return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite;
+}
+
+using ResourceOp = std::pair<int, XlaResourceOpKind>;
+
+string ResourceOpToString(const ResourceOp& resource_op) {
+ return strings::StrCat(
+ resource_op.first, ": ",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
+}
+
+// A copy-on-write set used to store the set of ResourceOps reaching a node in a
+// TensorFlow graph.
+//
+// TODO(sanjoy): It may be useful to pull this out into its own header at some
+// point.
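+//
+// Usage sketch (hypothetical node ids, for illustration only):
+//
+//   ResourceOpSet a;
+//   a.Add({/*node_id=*/1, XlaResourceOpKind::kRead});
+//   ResourceOpSet b;
+//   b.Add(a);  // b shares a's storage and freezes a.
+//   b.Add({/*node_id=*/2, XlaResourceOpKind::kWrite});  // b copies first.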
+class ResourceOpSet {
+ private:
+ using Impl = gtl::FlatSet<ResourceOp>;
+
+ public:
+ ResourceOpSet() = default;
+
+  // Adds all ResourceOps in `other` to this set.
+ void Add(const ResourceOpSet& other) {
+ CHECK(!frozen_);
+ if (other.impl_ == impl_) {
+ other.frozen_ = true;
+ return;
+ }
+
+ if (!impl_) {
+ other.frozen_ = true;
+ impl_ = other.impl_;
+ return;
+ }
+
+ for (ResourceOp resource_op : other) {
+ Add(resource_op);
+ }
+ }
+
+ void Add(const ResourceOp& resource_op) {
+ CHECK(!frozen_);
+ if (!IsCopy() && Contains(resource_op)) {
+ // We can avoid the copy if the item we want to insert already exists.
+ return;
+ }
+
+ EnsureIsCopied();
+ impl_->insert(resource_op);
+ }
+
+ Impl::const_iterator begin() const {
+ return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
+ }
+
+ Impl::const_iterator end() const {
+ return impl_ ? impl_->end() : GetEmptyImpl()->end();
+ }
+
+ bool Contains(const ResourceOp& resource_op) const {
+ return impl_ != nullptr && impl_->count(resource_op);
+ }
+
+ private:
+ bool IsCopy() const { return storage_ != nullptr; }
+
+ void EnsureIsCopied() {
+ if (storage_ == nullptr) {
+ storage_ = absl::make_unique<Impl>();
+ for (ResourceOp op : *this) {
+ storage_->insert(op);
+ }
+ impl_ = storage_.get();
+ }
+ }
+
+ static Impl* GetEmptyImpl() {
+ static Impl* empty_impl = new Impl;
+ return empty_impl;
+ }
+
+ Impl* impl_ = nullptr;
+ std::unique_ptr<Impl> storage_;
+
+ // frozen_ is true if there is another set pointing to this set's impl_. We
+ // can no longer add elements to this set in that case since the sets pointing
+ // to this set expect the contents of this set to be stable.
+ mutable bool frozen_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
+};
+
+string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
+ std::vector<string> elements_debug_string;
+ std::transform(resource_op_set.begin(), resource_op_set.end(),
+ std::back_inserter(elements_debug_string), ResourceOpToString);
+ return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
+}
+
+string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
+ return strings::StrCat(
+ "[", n.name(), ": ", n.type_string(), "(",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
+}
+} // namespace
+
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ std::vector<std::pair<int, int>>* result) {
+ CHECK(result->empty());
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ auto resource_op_set_for_node =
+ absl::make_unique<ResourceOpSet[]>(g.num_node_ids());
+
+ const bool vlog = VLOG_IS_ON(2);
+
+ for (Node* n : rpo) {
+ absl::optional<XlaResourceOpKind> op_kind;
+ TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
+ *n, flib_def, resource_ops_to_ignore, &op_kind));
+
+ ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];
+
+ // Merge the reaching resource operations for all the incoming edges to
+ // create the set of all possible resource ops reaching `n`.
+ for (const Edge* e : n->in_edges()) {
+ if (n->IsMerge() && e->src()->IsNextIteration()) {
+ // Ignore back-edges (see file comment).
+ continue;
+ }
+
+ const ResourceOpSet& incoming_op_set =
+ resource_op_set_for_node[e->src()->id()];
+ resource_op_set->Add(incoming_op_set);
+ }
+
+ // Add to the "incompatible resource ops" set if necessary.
+ if (op_kind) {
+ for (ResourceOp incoming_op : *resource_op_set) {
+ if (IsEdgeSafe(incoming_op.second, *op_kind)) {
+ continue;
+ }
+
+ if (vlog) {
+ VLOG(2) << "Unsafe edge: "
+ << NodeToString(*g.FindNodeId(incoming_op.first),
+ incoming_op.second)
+ << " -> " << NodeToString(*n, *op_kind);
+ }
+ result->push_back({incoming_op.first, n->id()});
+ }
+
+ resource_op_set->Add({n->id(), *op_kind});
+ }
+
+ if (vlog) {
+ VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
+ }
+ }
+
+ std::sort(result->begin(), result->end());
+ CHECK(std::unique(result->begin(), result->end()) == result->end());
+
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
new file mode 100644
index 0000000000..ae8cfeecad
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// An XLA cluster hoists all resource reads to the beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// returns the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// The restrictions are not transitive: it is fine to put A and C in the same
+// cluster even if the returned set contains (A,B) and (B,C).
+//
+// In other words, if these pairs are seen as edges in an undirected graph of
+// the nodes in `g` then auto-clustering is at least as constrained as the graph
+// coloring problem on this graph.
+//
+//
+// For instance if we auto-cluster all operations in this TensorFlow graph:
+//
+// ReadVariableOp0 -> ReadVariableOp1
+// |
+// v
+// AssignVariableOp0 -> AssignVariableOp1
+//
+// we will lose the ReadVariableOp0 -> ReadVariableOp1 and the
+// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for
+// XlaLaunchOp to issue ReadVariableOp1 before ReadVariableOp0 since it reads
+// all the resource variables when the cluster starts executing without any
+// particular ordering between them; same holds for the AssignVariableOp0 ->
+// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will
+// be respected by XlaLaunchOp though because all reads happen before all
+// writes.
+//
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// back-edges (i.e. the edges from NextIteration to Merge).
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// functional control flow nodes containing resource operations.
+//
+// If `resource_ops_to_ignore` is set then nodes for which it returns true are
+// ignored (we pretend these nodes are not resource operations).
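+//
+// Call sketch (a minimal illustration, not the actual call site):
+//
+//   std::vector<std::pair<int, int>> incompatible_pairs;
+//   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
+//       *graph, &flib_def, /*resource_ops_to_ignore=*/{},
+//       &incompatible_pairs));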
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ std::vector<std::pair<int, int>>* result);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
new file mode 100644
index 0000000000..e54b547abc
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
@@ -0,0 +1,540 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+Node* MakeModify(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
+ ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
+ var_handle, value_to_write);
+ return assign_add_op.operation.node();
+}
+
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
+
+Status ComputeIncompatiblePairs(Graph* g,
+ std::vector<std::pair<int, int>>* result) {
+ FixupSourceAndSinkEdges(g);
+ return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
+ result);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ MakeRead(root, "R");
+ MakeWrite(root, "W");
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> modify_read_pair = {modify->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> modify_write_pair = {modify->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, modify);
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 2);
+ std::pair<int, int> modify_write_pair = {modify->id(), write->id()};
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ std::pair<int, int> modify_read_pair = {modify->id(), read->id()};
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
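+// Returns a function library containing a single function `name` that takes
+// no arguments and returns the float constant 1.0.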
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def=*/{},
+ /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
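+// Creates a node in `graph` that calls the function `callee_name`, which
+// must already be present in the graph's function library.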
+Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
+ Status* status) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ return graph->AddNode(call_node, status);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> call_read_edge = {call->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(read, call);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> read_call_edge = {read->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> call_write_edge = {call->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_write_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(write, call);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_call_edge = {write->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(symbolic_gradient, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> symbolic_gradient_read_edge = {symbolic_gradient->id(),
+ read->id()};
+ EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(write, symbolic_gradient);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_symbolic_gradient_edge = {write->id(),
+ symbolic_gradient->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
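+ // Of the six ordered pairs of resource operations in the chain, only
+ // read_0 -> write_1 (a read before a write) is compatible, so five
+ // incompatible pairs are reported.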
+ ASSERT_EQ(incompatible_pairs.size(), 5);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+ std::pair<int, int> write_0_write_1_pair = {write_0->id(), write_1->id()};
+ std::pair<int, int> read_0_read_1_pair = {read_0->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+ root.graph()->AddControlEdge(write_1, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, Loop) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
+ Output loop_cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
+ Output enter_value =
+ ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
+ ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
+ ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName("exit"), iv.output);
+ Output next_iteration =
+ ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
+ TF_ASSERT_OK(
+ root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));
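+ // The nodes above form a standard TF while-loop skeleton:
+ // Enter -> Merge -> Switch -> (Exit | NextIteration -> back to Merge).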
+
+ Node* write = MakeWrite(root, "W");
+ Node* read = MakeRead(root, "R");
+
+ root.graph()->AddControlEdge(iv.output.node(), write);
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, next_iteration.node());
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
+ return arg_def.type() == DT_RESOURCE;
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 38adacd93b..4f2fabd658 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_map>
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -207,4 +208,27 @@ bool HasResourceInputOrOutput(const Node& node) {
void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
+
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ GraphCycles* cycles) {
+ std::vector<std::pair<int, int>> unsafe_deps;
+ TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
+ *graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
+
+ // A pair (P, Q) in `unsafe_deps` denotes that P and Q, both of which are
+ // operations that interact with resource variables, must not be put in the
+ // same cluster. We enforce this constraint by creating a phantom node, X,
+ // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
+ // and Q together since that would create a cycle with X.
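+ //
+ // For example, for unsafe_deps = {{3, 7}} we create a phantom node, say 9,
+ // and add edges 3->9 and 9->7; merging 3 and 7 would contract them in the
+ // cycle detection graph and produce the cycle {3,7} -> 9 -> {3,7}.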
+
+ for (std::pair<int, int> unsafe_dep : unsafe_deps) {
+ int phantom_node_id = cycles->NewNode();
+ CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
+ CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 662a53d89e..b0439a63ca 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -55,6 +55,13 @@ void RemoveFromXlaCluster(NodeDef* node_def);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
+// Adds edges to `cycles` to prevent clustering pairs of resource operations
+// whose relative execution order could not be preserved within one cluster.
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ GraphCycles* cycles);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc
index 2cb351e1ec..65bbf3efe8 100644
--- a/tensorflow/compiler/jit/xla_cluster_util_test.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 7140d47a94..ef6b0e67d3 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
return CompileImpl(options, function, constant_args, variable_args, ctx,
compilation_result, executable, compile_options, false);
}
@@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
const NodeDef& def = ctx->op_kernel().def();
NameAttrList name;
name.set_name(def.op());
@@ -256,7 +256,7 @@ Status XlaCompilationCache::CompileImpl(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op) {
CHECK_NE(executable, nullptr);
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
@@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl(
entry->compiled = true;
if (compile_single_op) {
- entry->compilation_status = compiler.CompileSingleOp(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- signature.name, ctx, args, &entry->compilation_result);
+ entry->compilation_status =
+ compiler.CompileSingleOp(compile_options, signature.name, ctx, args,
+ &entry->compilation_result);
} else {
entry->compilation_status = compiler.CompileFunction(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- function, args, &entry->compilation_result);
+ compile_options, function, args, &entry->compilation_result);
}
TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index fc5f008f4f..10ad87e38c 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
@@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase {
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
xla::LocalClient* client() const { return client_; }
const DeviceType& device_type() const { return device_type_; }
@@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index dd84fb34c1..3ba48e8c31 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
- result, executable, &compile_options);
+ result, executable, compile_options);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 2027ec7737..ee07c5c964 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -184,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- if (status.ok()) {
- xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback
- // to avoid a deadlock. If done() is the callback that ends an
- // Executor's run, the Executor may call XlaDevice::Sync() inside the
- // callback. This deadlocks, because XlaDevice::Sync() waits for all
- // stream activity to complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
- return;
- }
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
@@ -208,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
- xla_tensor->set_host_tensor(*cpu_tensor);
-
+ if (status.ok()) {
+ xla_tensor->set_host_tensor(*cpu_tensor);
+ }
done(status);
}
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 4b499b1613..915c5afa79 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -208,6 +208,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
index 5736760a87..b77b207908 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
+TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output var_handle =
+ ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
+ Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
+ Output begin = ops::Const(root.WithOpName("begin"), 0);
+ Output end = ops::Const(root.WithOpName("end"), 1);
+ Output strides = ops::Const(root.WithOpName("strides"), 1);
+ ops::ResourceStridedSliceAssign assign_1(
+ root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
+ ops::ResourceStridedSliceAssign assign_2(
+ root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
+ root.graph()->AddControlEdge(assign_1.operation.node(),
+ assign_2.operation.node());
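+ // assign_1 -> assign_2 is a Write -> Write dependency on the same
+ // resource, which the resource operation safety analysis forbids within a
+ // single cluster.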
+ grappler::GrapplerItem item;
+ root.graph()->ToGraphDef(&item.graph);
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 235bef07b3..94e08b6efe 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1191,3 +1191,19 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "xla_ops_test",
+ size = "small",
+ srcs = ["xla_ops_test.py"],
+ disabled_backends = ["cpu_ondemand"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 4a281c37e4..ed4940f204 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1372,5 +1372,40 @@ class BinaryOpsTest(xla_test.XLATestCase):
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
+ def testBroadcastTo(self):
+ for dtype in self.all_types:
+ x = np.random.randint(0, high=100, size=[2, 3]).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([2, 3], dtype=np.int32),
+ expected=x)
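+ # Unlike NumPy-style broadcasting, the XLA kernel also accepts output
+ # dimensions that are a multiple of the input dimension, tiling the input
+ # along that axis.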
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([6, 6], dtype=np.int32),
+ expected=np.tile(x, [3, 2]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 4, 3], dtype=np.int32),
+ expected=np.tile(x, [7, 2, 1]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 0, 3], dtype=np.int32),
+ expected=np.zeros([7, 0, 3], dtype=dtype))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 1, 2, 9], dtype=np.int32),
+ expected=np.tile(x, [7, 1, 1, 3]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ np.zeros([2, 0], dtype=dtype),
+ np.array([4, 0], dtype=np.int32),
+ expected=np.zeros([4, 0], dtype=dtype))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 1a82fcbb2a..6fe5a66e0e 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase):
image_np,
target_shape,
expected=None,
- large_tolerance=False):
+ large_tolerance=False,
+ align_corners=True):
if expected is None:
self.fail("expected must be specified")
with self.cached_session() as sess, self.test_scope():
image = array_ops.placeholder(image_np.dtype)
resized = gen_image_ops.resize_bilinear(
- image, target_shape, align_corners=True)
+ image, target_shape, align_corners=align_corners)
out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]})
if large_tolerance:
self.assertAllClose(
@@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase):
dtype=np.float32)),
large_tolerance=True)
+ def testNonAlignCorners3x2To6x4(self):
+ input_data = [[64, 32], [32, 64], [50, 100]]
+ expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0],
+ [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0],
+ [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [6, 4],
+ expected=np.array(expected_data, dtype=np.float32),
+ align_corners=False)
+
+ def testNonAlignCorners6x4To3x2(self):
+ input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127],
+ [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]]
+ expected_data = [[127, 64], [64, 127], [50, 100]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [3, 2],
+ expected=np.array(expected_data, dtype=dtype),
+ align_corners=False)
+
class NonMaxSuppressionTest(xla_test.XLATestCase):
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
new file mode 100644
index 0000000000..b2f026df6c
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA op wrappers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected,
+ equality_fn=None):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ result = session.run(output, feeds)
+ if not equality_fn:
+ equality_fn = self.assertAllClose
+ equality_fn(result, expected, rtol=1e-3)
+
+ def testAdd(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.add,
+ args=(np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype)),
+ expected=np.array([5, 7, 9], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(0,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 9], [14, 15]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(1,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 13], [10, 15]], dtype=dtype))
+
+ def testBroadcast(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.broadcast(x, (7, 42)),
+ args=(v,),
+ expected=np.tile(v, (7, 42, 1, 1)))
+
+ def testShiftRightLogical(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
+
+ def testShiftRightArithmetic(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([-1, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
+
+ PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT,
+ xla_data_pb2.PrecisionConfigProto.HIGH,
+ xla_data_pb2.PrecisionConfigProto.HIGHEST)
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testConv(self, precision):
+ for dtype in set(self.float_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ def conv_1d_fn(lhs, rhs):
+ dnums = xla_data_pb2.ConvolutionDimensionNumbers()
+ num_spatial_dims = 1
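+ # These dimension numbers describe an NCW layout: dimension 0 is batch
+ # and dimension 1 is the feature dimension on the operands and output,
+ # the kernel is laid out as (output feature, input feature, spatial),
+ # and dimension 2 is the single spatial dimension throughout.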
+ dnums.input_batch_dimension = 0
+ dnums.input_feature_dimension = 1
+ dnums.output_batch_dimension = 0
+ dnums.output_feature_dimension = 1
+ dnums.kernel_output_feature_dimension = 0
+ dnums.kernel_input_feature_dimension = 1
+ dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.conv(
+ lhs,
+ rhs,
+ window_strides=(1,),
+ padding=((2, 1),),
+ lhs_dilation=(1,),
+ rhs_dilation=(2,),
+ dimension_numbers=dnums,
+ precision_config=precision_config)
+
+ self._assertOpOutputMatchesExpected(
+ conv_1d_fn,
+ args=(
+ np.array([[[3, 4, 5, 6]]], dtype=dtype),
+ np.array([[[-2, -3]]], dtype=dtype),
+ ),
+ expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testDotGeneral(self, precision):
+ for dtype in self.float_types:
+
+ def dot_fn(lhs, rhs):
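+ # These dimension numbers express a batched matmul: dimension 0 is the
+ # batch dimension of both operands, and lhs dimension 2 contracts with
+ # rhs dimension 1.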
+ dnums = xla_data_pb2.DotDimensionNumbers()
+ dnums.lhs_contracting_dimensions.append(2)
+ dnums.rhs_contracting_dimensions.append(1)
+ dnums.lhs_batch_dimensions.append(0)
+ dnums.rhs_batch_dimensions.append(0)
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfigProto()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.dot_general(
+ lhs,
+ rhs,
+ dimension_numbers=dnums,
+ precision_config=precision_config)
+
+ lhs = np.array(
+ [
+ [[1, 2], [3, 4]],
+ [[5, 6], [7, 8]],
+ ], dtype=dtype)
+ rhs = np.array(
+ [
+ [[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]],
+ ], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ dot_fn,
+ args=(lhs, rhs),
+ expected=np.array(
+ [
+ [[9, 12, 15], [19, 26, 33]],
+ [[95, 106, 117], [129, 144, 159]],
+ ],
+ dtype=dtype))
+
+ def testNeg(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.neg,
+ args=(np.array([1, 2, 3], dtype=dtype),),
+ expected=np.array([-1, -2, -3], dtype=dtype))
+
+ def testPad(self):
+ for dtype in self.numeric_types:
+
+ def pad_fn(x):
+ return xla.pad(
+ x,
+ padding_value=7,
+ padding_low=[2, 1],
+ padding_high=[1, 2],
+ padding_interior=[1, 0])
+
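+ # Each padded dimension has size low + input + high + interior * (input -
+ # 1): here [2+2+1+1, 1+2+2+0] = [6, 5].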
+ self._assertOpOutputMatchesExpected(
+ pad_fn,
+ args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),),
+ expected=np.array(
+ [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7],
+ [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
+ dtype=dtype))
+
+ def testReduce(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def sum_reducer(x, y):
+ return x + y
+
+ def sum_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([12, 15, 18, 21], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([6, 22, 38], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0, 1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=dtype(66))
+
+ @function.Defun(dtype, dtype)
+ def mul_reducer(x, y):
+ return x * y
+
+ def mul_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ mul_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([0, 45, 120, 231], dtype=dtype))
+
+ def testSelectAndScatter(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def add_scatter(x, y):
+ return x + y
+
+ @function.Defun(dtype, dtype)
+ def ge_select(x, y):
+ return x >= y
+
+ def test_fn(operand, source):
+ return xla.select_and_scatter(
+ operand,
+ window_dimensions=[2, 3, 1, 1],
+ window_strides=[2, 2, 1, 1],
+ padding=[[0, 0]] * 4,
+ source=source,
+ init_value=0,
+ select=ge_select,
+ scatter=add_scatter)
+
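+ # Each 2x3 window (stride 2x2) selects its maximum element via ge_select
+ # and adds the matching source value at that position; the two top-row
+ # windows overlap and both select position (1, 2), which accumulates
+ # 2 + 6 = 8.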
+ self._assertOpOutputMatchesExpected(
+ test_fn,
+ args=(np.array(
+ [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6],
+ [0, 6, 2, 10, 2]],
+ dtype=dtype).reshape((4, 5, 1, 1)),
+ np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))),
+ expected=np.array(
+ [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0],
+ [0, 0, 0, 1, 0]],
+ dtype=dtype).reshape((4, 5, 1, 1)))
+
+ def testTranspose(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 85fd0c9217..92e577bb7b 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -39,6 +39,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -88,6 +89,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -221,13 +223,11 @@ cc_library(
srcs = [
"literal_util.cc",
"shape_util.cc",
- "str_util.cc",
"type_util.cc",
],
hdrs = [
"literal_util.h",
"shape_util.h",
- "str_util.h",
"type_util.h",
],
visibility = [":friends"],
@@ -256,6 +256,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -307,6 +308,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -374,19 +376,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- ],
-)
-
-tf_cc_test(
- name = "str_util_test",
- srcs = [
- "str_util_test.cc",
- ],
- deps = [
- ":common",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -459,6 +449,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -482,6 +473,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -609,3 +601,30 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "resource_operation_table",
+ srcs = ["resource_operation_table.cc"],
+ hdrs = ["resource_operation_table.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_table_test",
+ srcs = ["resource_operation_table_test.cc"],
+ deps = [
+ ":resource_operation_table",
+ ":xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index de1008803d..e8673d7790 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -23,11 +23,11 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
-
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
- std::vector<bool>* compile_time_const_args) {
+ std::vector<bool>* compile_time_const_args,
+ std::vector<bool>* compile_time_const_nodes) {
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
"Rank",
@@ -36,9 +36,16 @@ Status BackwardsConstAnalysis(const Graph& g,
"Size",
};
+ std::vector<bool> compile_time_const_nodes_impl;
+ if (compile_time_const_nodes) {
+ CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
+ } else {
+ compile_time_const_nodes_impl.resize(g.num_node_ids());
+ compile_time_const_nodes = &compile_time_const_nodes_impl;
+ }
+
Status status;
- std::unordered_set<const Node*> must_be_const;
- auto visit = [&status, &metadata_ops, &must_be_const,
+ auto visit = [&status, &metadata_ops, compile_time_const_nodes,
compile_time_const_args](Node* node) {
if (!status.ok()) return;
@@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g,
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
- if (must_be_const.find(node) != must_be_const.end()) {
+ if ((*compile_time_const_nodes)[node->id()]) {
if (node->type_string() == "_Arg") {
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
if (!status.ok()) return;
- compile_time_const_args->at(index) = true;
+ if (compile_time_const_args) {
+ (*compile_time_const_args)[index] = true;
+ }
return;
}
for (const Edge* pred : node->in_edges()) {
if (!pred->IsControlEdge()) {
- must_be_const.insert(pred->src());
+ (*compile_time_const_nodes)[pred->src()->id()] = true;
}
}
return;
@@ -80,7 +89,7 @@ Status BackwardsConstAnalysis(const Graph& g,
for (Edge const* edge : node->in_edges()) {
if (edge->dst_input() >= name_range->second.first &&
edge->dst_input() < name_range->second.second) {
- must_be_const.insert(edge->src());
+ (*compile_time_const_nodes)[edge->src()->id()] = true;
}
}
}
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index 634b97d7e3..af57e5a403 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -23,10 +23,18 @@ limitations under the License.
namespace tensorflow {
-// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that
-// must be compile-time constants.
+// Backwards dataflow analysis that finds nodes in a graph that must be
+// compile-time constants for us to be able to lower the graph to XLA.
+//
+// The indices of the arguments to `graph` that must be constant are returned in
+// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not
+// null.
+//
+// The ids of the nodes in `graph` that must be constant are returned in
+// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
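+//
+// For example, the `shape` operand of a Reshape must be a compile-time
+// constant, so for `Reshape(data, shape)` this analysis marks the node
+// producing `shape` (and the corresponding argument index, if that node is
+// an `_Arg`) as compile-time constant.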
Status BackwardsConstAnalysis(const Graph& graph,
- std::vector<bool>* compile_time_const_args);
+ std::vector<bool>* compile_time_const_arg_indices,
+ std::vector<bool>* compile_time_const_nodes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
index 992b12c06d..56065be894 100644
--- a/tensorflow/compiler/tf2xla/const_analysis_test.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) {
auto c = ops::Reshape(root, arg2, b);
auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));
- Graph graph(OpRegistry::Global());
- TF_ASSERT_OK(root.ToGraph(&graph));
+ FixupSourceAndSinkEdges(root.graph());
std::vector<bool> const_args(4, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ std::vector<bool> const_nodes(root.graph()->num_node_ids(), false);
+ TF_ASSERT_OK(
+ BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes));
// Arg 0 doesn't need to be constant since the graph only uses its shape.
// Arg 1 must be constant because it flows to the shape argument of a Reshape.
// Arg 2 is used only as the value input to a Reshape and need not be const.
// Arg 3 is used as the reduction-indices argument to Sum and must be const.
EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true}));
+
+ EXPECT_FALSE(const_nodes[arg0.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg1.node()->id()]);
+ EXPECT_FALSE(const_nodes[arg2.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg3.node()->id()]);
}
// Regression test for a case where the backward const analysis did
@@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector<bool> const_args(3, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
}
@@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector<bool> const_args(2, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index f14cfca4ea..b5667ca0d3 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -52,11 +53,10 @@ string DebugString(CondStateMap::CondId cond_state) {
if (cond_state == nullptr || cond_state->empty()) return "[]";
return strings::StrCat(
"[",
- tensorflow::str_util::Join(
- *cond_state, ", ",
- [](string* output, const CondStateMap::CondNode& node) {
- strings::StrAppend(output, node.ToString());
- }),
+ absl::StrJoin(*cond_state, ", ",
+ [](string* output, const CondStateMap::CondNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
"]");
}
@@ -169,10 +169,10 @@ using CondArgNodes = std::vector<CondArgNode>;
string DebugString(const CondArgNodes& nodes) {
return strings::StrCat(
"[",
- tensorflow::str_util::Join(nodes, ", ",
- [](string* output, const CondArgNode& node) {
- strings::StrAppend(output, node.ToString());
- }),
+ absl::StrJoin(nodes, ", ",
+ [](string* output, const CondArgNode& node) {
+ strings::StrAppend(output, node.ToString());
+ }),
"]");
}
@@ -387,8 +387,9 @@ Status Conditional::BuildArgumentNodes() {
}
if (!has_input) {
return errors::Internal(
- "Failed to functionalize control flow with merge '", m->name(),
- "' that doesn't have input on ", Branch_Name(branch), " branch.");
+ "Failed to functionalize control flow with merge ",
+ FormatNodeForError(*m), " that doesn't have input on ",
+ Branch_Name(branch), " branch.");
}
}
}
@@ -469,8 +470,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
// but revisit to improve the testing to enable making this an
// error.
LOG(WARNING) << errors::InvalidArgument(
- "Graph contains node ", src->name(), " that feeds into node ",
- dst->name(),
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
" but these nodes are in different control contexts (",
DebugString(src_id), " vs ", DebugString(dst_id),
" (detected during out edge testing)");
@@ -512,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) {
node_map.at(src->id()) = output->CopyNode(src);
} else {
return errors::InvalidArgument(
- "Graph contains node ", src->name(), " that feeds into node ",
- dst->name(),
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
" but these nodes are in different control contexts (",
DebugString(src_id), " vs ", DebugString(dst_id),
" (detected during in edge testing)");
@@ -675,7 +676,8 @@ Status Conditional::AddOutputEdges(Graph* graph) {
int dst_input = edge->dst_input();
if (edge->src_output() > 0) {
return errors::Unimplemented("Output of index (", edge->src_output(),
- ") of merge node ", node->name());
+ ") of merge node ",
+ FormatNodeForError(*node));
}
bool control_edge = edge->IsControlEdge();
@@ -1060,7 +1062,8 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
CondStateMap::CondId prop = StateAlongEdge(e);
auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst));
- TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
cond_state_map_.ResetId(dst, id_or.ValueOrDie());
}
@@ -1090,7 +1093,8 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) {
// Joining the state between the current and propagated state.
CondStateMap::CondId prop = StateAlongEdge(e);
auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst));
- TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
cond_state_map_.ResetId(dst, id_or.ValueOrDie());
}
}
@@ -1117,7 +1121,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
}
if (non_dead_edge == nullptr) {
- return errors::InvalidArgument("Merge node ", node->name(),
+ return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
" has no non-dead inputs.");
}
cond_state_map_.MarkDead(node);
@@ -1169,7 +1173,8 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
if (IsMerge(dst_node)) {
auto id_or =
JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node));
- TF_RETURN_IF_ERROR(id_or.status());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst_node));
cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
} else {
auto id_or =
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
index a0544b69e9..61940e3586 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/graph.h"
@@ -43,11 +44,11 @@ xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
template <typename T>
string NodesToString(const T& nodes) {
return strings::StrCat("{",
- str_util::Join(nodes, ",",
- [](string* output, const Node* node) {
- strings::StrAppend(output,
- node->name());
- }),
+ absl::StrJoin(nodes, ",",
+ [](string* output, const Node* node) {
+ strings::StrAppend(output,
+ node->name());
+ }),
"}");
}
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index e4fdf0a618..ba37ed3337 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
std::vector<bool> compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
- BackwardsConstAnalysis(*graph, &compile_time_constant_flags));
+ BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
+ /*compile_time_const_nodes=*/nullptr));
args->resize(expressions.size());
for (int i = 0; i < args->size(); ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index b1366e9e31..c1438f893f 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -22,6 +22,7 @@ tf_kernel_library(
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
+ "broadcast_to_op.cc",
"bucketize_op.cc",
"cast_op.cc",
"categorical_op.cc",
@@ -100,6 +101,12 @@ tf_kernel_library(
"unary_ops.cc",
"unpack_op.cc",
"variable_ops.cc",
+ "xla_broadcast_helper_op.cc",
+ "xla_conv_op.cc",
+ "xla_dot_op.cc",
+ "xla_pad_op.cc",
+ "xla_reduce_op.cc",
+ "xla_select_and_scatter_op.cc",
],
hdrs = [
"index_ops.h",
@@ -108,6 +115,8 @@ tf_kernel_library(
deps = [
":if_op",
":while_op",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index ba3b1c9dab..2e383b1473 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for broadcasting used in gradient
// code.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
const int64 len = bcast.output_shape().size();
Tensor output(DT_INT32, TensorShape({len}));
@@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
Output(ctx, 0, bcast.grad_x_reduce_idx());
Output(ctx, 1, bcast.grad_y_reduce_idx());
}
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
new file mode 100644
index 0000000000..4bd7c74dca
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+namespace {
+
+class BroadcastToOp : public XlaOpKernel {
+ public:
+ explicit BroadcastToOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ TensorShape output_shape;
+ OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+
+ OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
+ errors::InvalidArgument(
+ "Input rank (", input_shape.dims(),
+ ") must be less than or equal to the output rank (",
+ output_shape.dims(), ")"));
+
+ auto input_dims = input_shape.dim_sizes();
+ auto output_dims = output_shape.dim_sizes();
+
+ // Broadcasting is done right-to-left on right-aligned dimensions; reverse
+ // the two vectors so elements to be broadcast are aligned.
+ absl::c_reverse(input_dims);
+ absl::c_reverse(output_dims);
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ for (int i = 0; i < output_shape.dims(); ++i) {
+ if (i < input_shape.dims()) {
+ OP_REQUIRES(
+ context,
+ (output_dims[i] == 0 && input_dims[i] == 0) ||
+ (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
+ errors::InvalidArgument("invalid shape to broadcast from ",
+ input_shape.DebugString(), " to ",
+ output_shape.DebugString()));
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ broadcast_shape.push_back(output_dims[i]);
+ }
+ if (output_dims[i] != input_dims[i]) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(input_dims[i]);
+ broadcast_shape.push_back(output_dims[i] / input_dims[i]);
+ }
+ } else {
+ broadcast_shape.push_back(output_dims[i]);
+ }
+ }
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
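+    // For example, broadcasting from [2, 3] to [4, 3] (a 2x tiling of the
+    // first dimension) yields broadcast_shape = [2, 2, 3] and
+    // broadcast_dims = [1, 2]; the Reshape below flattens [2, 2, 3] back to
+    // the requested [4, 3].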
+ xla::XlaOp output = xla::Reshape(
+ xla::BroadcastInDim(context->Input(0),
+ xla::ShapeUtil::MakeShape(
+ context->input_xla_type(0), broadcast_shape),
+ broadcast_dims),
+ output_shape.dim_sizes());
+ context->SetOutput(0, output);
+ }
+};
+
+REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"),
+ BroadcastToOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 8d75624e74..8e071bf0b7 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -32,13 +32,13 @@ namespace {
//
// 1. S := (N - 1) / gcd(N-1, R-1)
// 2. k := (R - 1) / gcd(N-1, R-1)
-// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1)
+// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to Scale from 7x7 -> 15x15:
//
// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
-// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2)
+// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2)
//
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
@@ -65,6 +65,8 @@ namespace {
// 1/9 * 3 6 9 6 3
// 2 4 6 4 2
// 1 2 3 2 1
+// Note that the convolution kernel matrix is separable, so we can instead
+// apply two consecutive 1D kernels of dimension 2k-1, one along each axis.
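+// For example, the 5x5 kernel above is the outer product of the 1D kernel
+// 1/3 * [1 2 3 2 1] with itself, so two 1D convolutions reproduce it exactly.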
// Computes the size of the convolutional kernel and stride to use when resizing
// from in_size to out_size.
@@ -76,7 +78,8 @@ struct ResizeConvolutionDims {
std::vector<int64> stride;
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
- gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size) {
+ gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size,
+ bool align_corners) {
CHECK_EQ(in_size.size(), out_size.size());
int num_spatial_dims = in_size.size();
ResizeConvolutionDims dims;
@@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
// entry before resizing.
dims.stride[i] = dims.kernel_size[i] = 1;
} else {
- int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size[i] - 1),
- static_cast<uint64>(out_size[i] - 1));
- dims.stride[i] = (in_size[i] - 1) / gcd;
- dims.kernel_size[i] = (out_size[i] - 1) / gcd;
+ // The scaling factor changes depending on the alignment of corners.
+ const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i];
+ const int64 out_size_factor =
+ align_corners ? out_size[i] - 1 : out_size[i];
+
+ int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
+ static_cast<uint64>(out_size_factor));
+ dims.stride[i] = in_size_factor / gcd;
+ dims.kernel_size[i] = out_size_factor / gcd;
}
}
return dims;
}
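+
+// For example, resizing from 4 to 8 pixels with align_corners=false uses the
+// factors 4 and 8 directly: gcd(4, 8) = 4, so stride = 1 and kernel_size = 2.
+// With align_corners=true the factors are 3 and 7, which are coprime, giving
+// stride = 3 and kernel_size = 7.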
+// The upper padding of the input needed by ConvGeneralDilated calls is
+// determined by solving two related relationships (assuming rhs_dilation == 1,
+// i.e. an undilated kernel):
+// 1. dilated_input_dim = lower_padding + upper_padding
+//                        + lhs_dilation * (in_size - 1) + 1
+// 2. dilated_input_dim = (2 * dims.kernel_size - 1)
+//                        + dims.stride * (out_size - 1)
+int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
+ int64 stride) {
+ return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) -
+ 1 - (kernel_size * (in_size - 1));
+}
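+
+// As a concrete check of the relationships above: with in_size = 3,
+// out_size = 6, kernel_size = 2 and stride = 1, both equations give
+// dilated_input_dim = 8, and the formula returns an upper padding of 2.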
+
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
@@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector<int64> in_size,
std::vector<int64> out_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
// Picture for a 1x3 to 1x4 resize:
// stride = 2, kernel size = 3
// Input:
@@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, out_size);
+ ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
xla::XlaOp output;
- // Split convolutions into independent dimensions if they wmuld be a very
+
+ // Concatenation and padding below currently assumes num_spatial_dims is 2 to
+ // prevent needless code complexity.
+ CHECK_EQ(num_spatial_dims, 2)
+ << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
+ std::vector<int64> upper_padding(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ upper_padding[i] = dims.kernel_size[i] - 1;
+ }
+ xla::XlaOp input_data = input;
+
+ if (!align_corners) {
+    // When align_corners is false, TensorFlow's resize indexing can read
+    // beyond the upper bound and is clamped to prevent out-of-bounds accesses;
+    // this is conceptually the same as extending the edges of the input.
+ // We emulate this by copying the last row/column of the input.
+    // Calculate the padding that would be needed, then determine how far to
+    // extend the border before lhs dilation.
+ std::vector<int64> num_extended(num_spatial_dims);
+ upper_padding[0] = CalculateUpperPadding(
+ in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] = CalculateUpperPadding(
+ in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
+ num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
+ num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
+
+ if (num_extended[0] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, in_size[0] - 1, 0, 0},
+ {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[0]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
+ }
+ }
+
+ if (num_extended[1] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, 0, in_size[1] - 1, 0},
+ {1, in_size[0] + num_extended[0], in_size[1], channels},
+ {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[1]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
+ }
+ }
+
+    // The effective input is now (in_size + num_extended) in each spatial
+    // dimension after the Slice/ConcatInDim above, so recalculate the padding.
+ upper_padding[0] =
+ CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
+ dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] =
+ CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
+ dims.kernel_size[1], dims.stride[1]);
+ }
+
+ // Split convolutions into independent dimensions if they would be a very
// large kernel.
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- output = xla::ConvGeneralDilated(
- input, kernel, dims.stride,
- /*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
- {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
- /*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ output =
+ xla::ConvGeneralDilated(input_data, kernel, dims.stride,
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, upper_padding[0]},
+ {dims.kernel_size[1] - 1, upper_padding[1]}},
+ /*lhs_dilation=*/dims.kernel_size,
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
output = xla::ConvGeneralDilated(
- input, kernel0, {dims.stride[0], 1},
+ input_data, kernel0, {dims.stride[0], 1},
/*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
+ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp kernel1 =
@@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
output = xla::ConvGeneralDilated(
output, kernel1, {1, dims.stride[1]},
/*padding=*/
- {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
}
@@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector<int64> in_size,
std::vector<int64> grad_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, grad_size);
+ ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
// To form the backward convolution, we keep the kernel unchanged (it is
// already symmetric) and swap the roles of strides and LHS dilation.
@@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel {
public:
explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
- OP_REQUIRES(
- ctx, align_corners_ == true,
- errors::Unimplemented(
- "ResizeBilinear with align_corners=False is not yet implemented"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel {
// If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
// dimension i.
- std::vector<int64> slice_size = in_size;
bool slice_input = false;
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] > 1 && out_size[i] == 1) {
// If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
// entry before resizing.
slice_input = true;
- slice_size[i] = 1;
+ in_size[i] = 1;
}
}
if (slice_input) {
- input = xla::Slice(input, {0, 0, 0, 0},
- {batch, slice_size[0], slice_size[1], channels},
- {1, 1, 1, 1});
+ input =
+ xla::Slice(input, {0, 0, 0, 0},
+ {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
}
// Output is always type float.
@@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel {
// operations along different dimensions.
// Given sufficient numerical stability and a<e<c and b<f<d, bilinear resize
// from image of size axb -> cxd is same as resizing axb -> exf -> cxd.
+    // This decomposition does not hold when align_corners_ is false: the
+    // special edge padding makes two successive resizes differ from a single
+    // resize.
//
  // This makes the convolution kernels smaller and the operation faster.
xla::XlaOp output = input;
@@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel {
(static_cast<float>(out_size[0]) - 1) / ((in_size[0] - 1) * 2),
(static_cast<float>(out_size[1]) - 1) / ((in_size[1] - 1) * 2)};
if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
- k[0] > 1 && k[1] > 1) {
+ k[0] > 1 && k[1] > 1 && align_corners_) {
std::vector<int64> next_out_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, next_out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, next_out_size,
+ channels, align_corners_);
input = output;
in_size = next_out_size;
} else {
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, out_size,
+ channels, align_corners_);
in_size = out_size;
}
} else {
output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
- in_size, out_size, channels);
+ in_size, out_size, channels,
+ align_corners_);
in_size = out_size;
}
}
@@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel {
std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, next_grad_size, channels);
+ b, grad, num_spatial_dims, in_size, next_grad_size, channels,
+ align_corners_);
grad = output;
in_size = next_grad_size;
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index d4d180aff8..f6f158a73b 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
-// Divide each element of an image by the count of elements that contributed to
-// that element during pooling.
-static xla::XlaOp AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape, xla::Padding padding,
- const std::vector<int64>& ksize, const std::vector<int64>& stride,
- int num_spatial_dims, TensorFormat data_format) {
- if (padding == xla::Padding::kValid) {
- // In VALID padding, all windows have the same number of elements
- // contributing to each average. Divide by the window size everywhere to
- // get the average.
- int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
- [](int64 a, int64 b) { return a * b; });
-
- auto divisor =
- XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return xla::Div(output, divisor);
- } else {
- // For SAME padding, the padding shouldn't be included in the
- // counts. We use another ReduceWindow to find the right counts.
-
- // TODO(phawkins): use a less brute-force way to compute this. Only
- // the boundary regions will have interesting values here.
-
- std::vector<int64> input_dim_sizes(num_spatial_dims);
- std::vector<int64> window_dims(num_spatial_dims);
- std::vector<int64> window_ksize(num_spatial_dims);
- std::vector<int64> window_stride(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
- input_dim_sizes[i] = input_shape.dim_size(dim);
- window_dims[i] = dim;
- window_ksize[i] = ksize[dim];
- window_stride[i] = stride[dim];
- }
-
- // Build a matrix of all 1s, with the same width/height as the input.
- const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = xla::Broadcast(
- XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
-
- // Perform a ReduceWindow with the same window size, strides, and padding
- // to count the number of contributions to each result element.
- auto reduce = xla::ReduceWindow(
- ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
- xla::Padding::kSame);
- auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
-
- return xla::Div(output, counts, window_dims);
- }
-}
-
class AvgPoolOp : public PoolingOp {
public:
AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
@@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("out_backprop must be ", num_dims(),
"-dimensional"));
- int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- int64 depth = out_backprop_shape.dim_size(depth_dim);
-
- // We can think of average-pooling as:
- // * a convolution with a kernel consisting entirely of 1s, where the
- // input feature and output feature are equal, and 0s everywhere else.
- // * followed by dividing by the counts.
- //
- // This then gives us an algorithm to build the gradient:
- // * divide out_backprop by the counts, followed by
- // * Conv2DBackpropInput specialized for that kernel, which simplifies to
- // a Pad and a ReduceWindow.
- //
- // For an explanation of backpropagation for convolution, see the comments
- // in third_party/tensorflow/core/kernels/conv_grad_ops.h
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- std::vector<int64> filter_dims(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- filter_dims[i] = ksize_[dim];
- }
- filter_dims[num_dims() - 2] = depth;
- filter_dims[num_dims() - 1] = depth;
- TensorShape filter_shape(filter_dims);
-
- // Reuse the logic from Conv2DBackpropInput to compute padding.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(
- ctx, ConvBackpropComputeDimensions(
- type_string(), /*num_spatial_dims=*/num_spatial_dims_,
- gradients_shape, filter_shape, out_backprop_shape, stride_,
- padding_, data_format_, &dims));
-
- // The input gradients are computed by a convolution of the output gradients
- // and the filter, with some appropriate padding. See the comment at the top
- // of conv_grad_ops.h for details.
- xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
- auto dtype = input_type(1);
+ std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
-
- // Divide the out_backprop values by the counts for each spatial position.
- std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
- auto out_backprop_div = AvgPoolDivideByCount(
- ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
- stride_int64s, num_spatial_dims_, data_format_);
-
- // Pad the gradients in the spatial dimensions. We use the same padding
- // as Conv2DBackpropInput.
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- auto* padding = padding_config.mutable_dimensions(dim);
- padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
- padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
- padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
- }
-
- auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
-
- // in_backprop = padded_gradients <conv> ones
- std::vector<int64> ones(num_dims(), 1LL);
- auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = xla::ReduceWindow(
- XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), ksize_,
- /* window_strides=*/ones, xla::Padding::kValid);
- ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype));
+ xla::PrimitiveType xla_reduction_type;
+ auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
+ auto converted_out_backprop =
+ xla::ConvertElementType(out_backprop, xla_reduction_type);
+ auto xla_data_format =
+ XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
+ auto padding_values =
+ MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
+ xla_padding, xla_data_format);
+ auto in_backprop =
+ xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
+ ksize_, stride_int64s, padding_values, xla_data_format,
+ /*counts_include_padding=*/padding_ == VALID);
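+    // With VALID padding every window is full, so dividing by the complete
+    // window size (counts_include_padding == true) matches TF's averaging;
+    // with SAME padding the padded elements must be excluded from the counts.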
+ // Convert the pooling result back to the input type before returning it.
+ xla::PrimitiveType xla_out_backprop_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
+ &xla_out_backprop_type));
+ ctx->SetOutput(0,
+ xla::ConvertElementType(in_backprop, xla_out_backprop_type));
}
protected:
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index b11a4ce36d..8102faad28 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel {
explicit ReduceWindowOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_dimensions", &window_dimensions_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_strides", &window_strides_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_));
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const DataType dtype = context->input_type(0);
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
const int rank = input_shape.dims();
- OP_REQUIRES(context, rank == window_dimensions_.size(),
+ OP_REQUIRES(context, rank == window_dimensions.size(),
errors::InvalidArgument(
"The size of window_dimensions must be equal to the input "
"rank (",
- window_dimensions_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == window_strides_.size(),
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
errors::InvalidArgument(
"The size of window_strides must be equal to the input "
"rank (",
- window_strides_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_low_.size(),
- errors::InvalidArgument(
- "The size of padding_low must be equal to the input "
- "rank (",
- padding_low_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_high_.size(),
- errors::InvalidArgument(
- "The size of padding_high must be equal to the input "
- "rank (",
- padding_high_.size(), " vs. ", rank, ")"));
-
- xla::XlaBuilder* builder = context->builder();
+ window_strides.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel {
compile_options.use_tuple_arg = false;
compile_options.resolve_compile_time_constants = false;
compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult reducer;
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
compile_options, *computation_,
@@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel {
xla::Shape scalar_shape;
OP_REQUIRES_OK(context,
TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of ReduceWindow reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
OP_REQUIRES(context,
- xla::ShapeUtil::Compatible(
- reducer.xla_output_shape,
- xla::ShapeUtil::MakeTupleShape({scalar_shape})),
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
errors::InvalidArgument(
- "Invalid output shape of ReduceWindow reducer. Expected ",
- xla::ShapeUtil::HumanString(scalar_shape), " got ",
- xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
-
- // Wraps the reducer in a computation that unpacks the output tuple.
- xla::XlaComputation wrapper;
- {
- std::unique_ptr<xla::XlaBuilder> cb =
- builder->CreateSubBuilder("wrapper");
- auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
- auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
- auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
- xla::GetTupleElement(outputs, 0);
- xla::StatusOr<xla::XlaComputation> result = cb->Build();
- OP_REQUIRES_OK(context, result.status());
- wrapper = std::move(result.ValueOrDie());
- }
-
- std::vector<std::pair<int64, int64>> padding(rank);
- for (int i = 0; i < rank; ++i) {
- padding[i] = {padding_low_[i], padding_high_[i]};
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
}
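+    // For example, a padding input of [[1, 2], [0, 0]] pads dimension 0 with
+    // one element below and two above, and leaves dimension 1 unpadded.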
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
- context->Input(0), context->Input(1), wrapper, window_dimensions_,
- window_strides_, padding);
+ context->Input(0), context->Input(1), *reducer.computation,
+ window_dimensions, window_strides, padding);
context->SetOutput(0, output);
}
private:
const NameAttrList* computation_;
- std::vector<int64> window_dimensions_;
- std::vector<int64> window_strides_;
- std::vector<int64> padding_low_;
- std::vector<int64> padding_high_;
TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp);
};
-REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp);
+REGISTER_XLA_OP(Name("XlaReduceWindow")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ ReduceWindowOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 6a71b8ca36..598248563b 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific reduction Ops.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -66,7 +67,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes));
VLOG(1) << "data shape: " << data_shape.DebugString();
- VLOG(1) << "axes : " << str_util::Join(axes, ",");
+ VLOG(1) << "axes : " << absl::StrJoin(axes, ",");
gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
std::vector<int64> xla_axes;
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 025ba82741..d6bd927135 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace {
@@ -33,7 +33,7 @@ namespace {
class SoftmaxOp : public XlaOpKernel {
public:
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- log_ = str_util::StartsWith(type_string(), "Log");
+ log_ = absl::StartsWith(type_string(), "Log");
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
new file mode 100644
index 0000000000..412afeaaad
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaBroadcastHelperOp : public XlaOpKernel {
+ public:
+ explicit XlaBroadcastHelperOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp lhs = context->Input(0);
+ xla::XlaOp rhs = context->Input(1);
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims();
+ const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape;
+ const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape;
+
+ std::vector<int64> broadcast_dims;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims",
+ &broadcast_dims));
+ if (broadcast_dims.empty()) {
+ OP_REQUIRES(
+ context,
+ lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 ||
+ rhs_shape.dims() == 0,
+ errors::InvalidArgument(
+ "If broadcast_dims is empty, both "
+ "arguments must have equal rank; "
+ "argument shapes, or at least one argument must be a scalar: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, broadcast_dims.size() == min_rank_shape->dims(),
+ errors::InvalidArgument(
+ "broadcast_dims must have size equal to the smaller argument rank; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ std::vector<int64> sorted_broadcast_dims = broadcast_dims;
+ absl::c_sort(sorted_broadcast_dims);
+ std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end());
+ OP_REQUIRES(context,
+ dims_set.size() == broadcast_dims.size() &&
+ broadcast_dims == sorted_broadcast_dims,
+ errors::InvalidArgument(
+ "Duplicate or nonmonotonic dimension in broadcast_dims; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]"));
+
+ std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL);
+ for (int i = 0; i < broadcast_dims.size(); ++i) {
+ const int dim = broadcast_dims[i];
+ OP_REQUIRES(
+ context, dim >= 0 && dim < broadcast_shape.size(),
+ errors::InvalidArgument(
+ "Invalid broadcast dimension (", dim, "); broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ broadcast_shape[dim] = min_rank_shape->dim_size(i);
+ }
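+    // For example, with lhs shape [2, 3], rhs shape [2, 4, 3] and
+    // broadcast_dims = [0, 2], broadcast_shape becomes [2, 1, 3]: the lhs
+    // dimensions are placed at positions 0 and 2 of the higher-rank shape.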
+ xla::PrimitiveType type = context->input_xla_type(0);
+ xla::Shape broadcast_xla_shape =
+ xla::ShapeUtil::MakeShape(type, broadcast_shape);
+ if (broadcast_lhs) {
+ lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims);
+ } else {
+ rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims);
+ }
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ }
+
+ private:
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp);
+};
+
+REGISTER_XLA_OP(
+ Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"),
+ XlaBroadcastHelperOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
new file mode 100644
index 0000000000..8848623868
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaConvOp : public XlaOpKernel {
+ public:
+ explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+ const TensorShape padding_shape = context->InputShape("padding");
+ std::vector<int64> window_strides;
+ std::vector<int64> lhs_dilation;
+ std::vector<int64> rhs_dilation;
+ int64 feature_group_count;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation",
+ &lhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation",
+ &rhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(
+ "feature_group_count", &feature_group_count));
+
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
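+    // For example, a feature_group_count equal to the input feature count
+    // requests a depthwise convolution, with each input feature convolved
+    // against its own slice of the kernel.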
+ xla::XlaOp output = xla::ConvGeneralDilated(
+ context->Input(0), context->Input(1), window_strides, padding,
+ lhs_dilation, rhs_dilation, dnums_, feature_group_count,
+ &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::ConvolutionDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
+};
+
+REGISTER_XLA_OP(Name("XlaConv")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("lhs_dilation")
+ .CompileTimeConstInput("rhs_dilation")
+ .CompileTimeConstInput("feature_group_count")
+ .CompileTimeConstInput("padding"),
+ XlaConvOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
new file mode 100644
index 0000000000..2fed53e5c0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaDotOp : public XlaOpKernel {
+ public:
+ explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
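+    // For example, a plain [m, k] x [k, n] matrix multiply corresponds to
+    // dimension numbers with lhs_contracting_dimensions = {1},
+    // rhs_contracting_dimensions = {0}, and no batch dimensions.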
+ xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
+ dnums_, &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+ xla::PrecisionConfigProto precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
+};
+
+REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
new file mode 100644
index 0000000000..59502d83c7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaPadOp : public XlaOpKernel {
+ public:
+ explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape padding_value_shape =
+ context->InputShape("padding_value");
+
+ std::vector<int64> padding_low;
+ std::vector<int64> padding_high;
+ std::vector<int64> padding_interior;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low",
+ &padding_low));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high",
+ &padding_high));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "padding_interior", &padding_interior));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == padding_low.size(),
+ errors::InvalidArgument(
+ "The size of padding_low must be equal to the input "
+ "rank (",
+ padding_low.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_high.size(),
+ errors::InvalidArgument(
+ "The size of padding_high must be equal to the input "
+ "rank (",
+ padding_high.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_interior.size(),
+ errors::InvalidArgument(
+ "The size of padding_interior must be equal to the input "
+ "rank (",
+ padding_interior.size(), " vs. ", rank, ")"));
+
+ auto non_negative = [](int64 x) { return x >= 0; };
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_low, non_negative),
+ errors::InvalidArgument("padding_low must be non-negative, got [",
+ absl::StrJoin(padding_low, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_high, non_negative),
+ errors::InvalidArgument("padding_high must be non-negative, got [",
+ absl::StrJoin(padding_high, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_interior, non_negative),
+ errors::InvalidArgument("padding_interior must be non-negative, got [",
+ absl::StrJoin(padding_interior, ","), "]"));
+
+ xla::PaddingConfig padding_config;
+ for (int i = 0; i < rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ dim->set_edge_padding_low(padding_low[i]);
+ dim->set_edge_padding_high(padding_high[i]);
+ dim->set_interior_padding(padding_interior[i]);
+ }
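+    // For example, with input shape [3], padding_low = [1], padding_high = [2]
+    // and padding_interior = [1], the result has 1 + 3 + 2 + 2 = 8 elements:
+    // {p, x0, p, x1, p, x2, p, p}.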
+
+ xla::XlaOp output =
+ xla::Pad(context->Input("input"), context->Input("padding_value"),
+ padding_config);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp);
+};
+
+REGISTER_XLA_OP(Name("XlaPad")
+ .CompileTimeConstInput("padding_low")
+ .CompileTimeConstInput("padding_high")
+ .CompileTimeConstInput("padding_interior"),
+ XlaPadOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
new file mode 100644
index 0000000000..fc2425f37b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaReduceOp : public XlaOpKernel {
+ public:
+ explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_));
+ OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce",
+ &dimensions_to_reduce_));
+ std::set<int64> dims_set(dimensions_to_reduce_.begin(),
+ dimensions_to_reduce_.end());
+ OP_REQUIRES(
+ context, dims_set.size() == dimensions_to_reduce_.size(),
+ errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
+ "argument to XlaReduce"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape init_value_shape = context->InputShape("init_value");
+ const DataType dtype = context->input_type(0);
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
+ errors::InvalidArgument("init_value must be a scalar"));
+
+ auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
+ OP_REQUIRES(context,
+ rank >= dimensions_to_reduce_.size() &&
+ absl::c_all_of(dimensions_to_reduce_, dim_in_range),
+ errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce"));
+
+ // Build the reducer function.
+ XlaCompiler::Argument reducer_arg;
+ reducer_arg.kind = XlaCompiler::Argument::kParameter;
+ reducer_arg.type = dtype;
+ reducer_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult reducer;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *reducer_,
+ {reducer_arg, reducer_arg}, &reducer));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaReduce reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ xla::XlaOp output =
+ xla::Reduce(context->Input("input"), context->Input("init_value"),
+ *reducer.computation, dimensions_to_reduce_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* reducer_;
+ std::vector<int64> dimensions_to_reduce_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
+};
+
+REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
new file mode 100644
index 0000000000..089776fcf7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
@@ -0,0 +1,147 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSelectAndScatterOp : public XlaOpKernel {
+ public:
+ explicit XlaSelectAndScatterOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_));
+ OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const DataType dtype = context->input_type(0);
+
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == window_dimensions.size(),
+ errors::InvalidArgument(
+ "The size of window_dimensions must be equal to the input "
+ "rank (",
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
+ errors::InvalidArgument(
+ "The size of window_strides must be equal to the input "
+ "rank (",
+ window_strides.size(), " vs. ", rank, ")"));
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
+
+ // Build the select function.
+ XlaCompiler::Argument select_arg;
+ select_arg.kind = XlaCompiler::Argument::kParameter;
+ select_arg.type = dtype;
+ select_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult select;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *select_computation_,
+ {select_arg, select_arg}, &select));
+
+ xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {});
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(select.xla_output_shape,
+ select_output_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaSelectAndScatter select. Expected ",
+ xla::ShapeUtil::HumanString(select_output_shape), " got ",
+ xla::ShapeUtil::HumanString(select.xla_output_shape)));
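+    // A typical select computation is a scalar comparison such as Ge; paired
+    // with an Add scatter, this reproduces the max-pool gradient.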
+
+ // Build the scatter function.
+ XlaCompiler::Argument scatter_arg;
+ scatter_arg.kind = XlaCompiler::Argument::kParameter;
+ scatter_arg.type = dtype;
+ scatter_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult scatter;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *scatter_computation_,
+ {scatter_arg, scatter_arg}, &scatter));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of scatter. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(scatter.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding(
+ context->Input("operand"), *select.computation, window_dimensions,
+ window_strides, padding, context->Input("source"),
+ context->Input("init_value"), *scatter.computation);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* select_computation_;
+ const NameAttrList* scatter_computation_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp);
+};
+
+REGISTER_XLA_OP(Name("XlaSelectAndScatter")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ XlaSelectAndScatterOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index cb7a40e23d..99511e9914 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:lib",
],
@@ -78,8 +78,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
@@ -119,6 +119,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index f666d22ea4..d8c050d09e 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -27,7 +27,8 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x, bool conjugate_y) {
+ bool transpose_y, bool conjugate_x, bool conjugate_y,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
@@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
y = xla::Conj(y);
}
+ xla::PrecisionConfigProto precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+
// If there are no batch dimensions, use a regular Dot.
// TODO(b/69062148) Remove this code when Dot emitters can be passed
// dimensions to transpose directly (i.e. without requiring a Transpose
@@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
if (batch_dimension_numbers.empty()) {
auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
- return xla::Dot(lhs, rhs);
+ return xla::Dot(lhs, rhs, &precision_proto);
}
xla::DotDimensionNumbers dot_dnums;
@@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- return xla::DotGeneral(x, y, dot_dnums);
+
+ return xla::DotGeneral(x, y, dot_dnums, &precision_proto);
});
}
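The new precision argument only changes the operand precision requested from the XLA backend for the underlying Dot/DotGeneral; the contraction itself is unchanged. A NumPy sketch (not part of the patch), assuming the batch dimensions of x and y already match:

import numpy as np

def batch_dot_reference(x, y, transpose_x=False, transpose_y=False,
                        conjugate_x=False, conjugate_y=False):
    # `precision` has no NumPy analogue; it is purely a backend hint.
    if conjugate_x:
        x = np.conj(x)
    if conjugate_y:
        y = np.conj(y)
    if transpose_x:
        x = np.swapaxes(x, -1, -2)
    if transpose_y:
        y = np.swapaxes(y, -1, -2)
    # Contracts the last dimension of x with the second-to-last of y.
    return np.matmul(x, y)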
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 8757b16a1c..6cfccd5553 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -45,7 +45,9 @@ namespace tensorflow {
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
bool transpose_y = false, bool conjugate_x = false,
- bool conjugate_y = false);
+ bool conjugate_y = false,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::DEFAULT);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 87d73eb3f0..67fb56510c 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -49,7 +49,8 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
+xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -101,7 +102,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2)))
auto l_ii =
@@ -121,7 +123,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// r.T)
auto dot = BatchDot(body_l, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
@@ -145,7 +148,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
} // namespace
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -181,14 +185,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
a = UpdateSliceInMinorDims(a, before - delta, {i, i});
}
// l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
- auto factorized = CholeskyUnblocked(x);
+ auto factorized = CholeskyUnblocked(x, precision);
l = UpdateSliceInMinorDims(l, factorized, {i, i});
if (i + k < n) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 1bef9bb166..60cd7ded53 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -30,7 +30,9 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256);
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
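As context for the precision plumbing, a NumPy transcription (not part of the patch) of the unblocked algorithm sketched in the cholesky.cc comment, restricted to a single matrix:

import numpy as np

def cholesky_unblocked_reference(a):
    n = a.shape[0]
    l = np.zeros_like(a)
    for j in range(n):
        row = l[j, :j]
        l[j, j] = np.sqrt(a[j, j] - np.dot(row, row))
        l[j + 1:, j] = (a[j + 1:, j] - np.dot(l[j + 1:, :j], row)) / l[j, j]
    return l

The blocked driver factors each diagonal block with this routine and performs every trailing-panel update through BatchDot at the requested precision, which is why HIGHEST is the default here.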
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index fc0c1ee838..b6f30d8d49 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -149,7 +149,8 @@ struct QRBlockResult {
xla::XlaOp taus; // Shape: [..., n]
xla::XlaOp vs; // Shape: [..., m, n]
};
-xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
+xla::StatusOr<QRBlockResult> QRBlock(
+ xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -190,8 +191,12 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
auto v_broadcast = xla::Reshape(v, shape);
// a[:, :] -= tau * np.dot(v[:, np.newaxis],
// np.dot(v[np.newaxis, :], a[:, :]))
- auto vva = BatchDot(v_broadcast, a);
- vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true);
+ auto vva =
+ BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ vva =
+ BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a = a - xla::Mul(tau, vva,
/*broadcast_dimensions=*/batch_dim_indices);
@@ -251,7 +256,8 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
// vs.
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
- xla::XlaOp taus, int64 m, int64 n) {
+ xla::XlaOp taus, int64 m, int64 n,
+ xla::PrecisionConfigProto::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 n_index = batch_dims.size() + 1;
@@ -272,9 +278,12 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
auto beta = DynamicSliceInMinorDims(taus, {j}, {1});
// yv has shape [..., n, 1]
- auto yv = BatchDot(y, v, /*transpose_x=*/true);
+ auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
// wyv has shape [..., m, 1]
- auto wyv = BatchDot(w, yv);
+ auto wyv =
+ BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
auto z = xla::Mul(
-beta, v + wyv,
@@ -321,8 +330,9 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
-xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
- int64 block_size) {
+xla::StatusOr<QRDecompositionResult> QRDecomposition(
+ xla::XlaOp a, int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -352,29 +362,36 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
int64 k = std::min(block_size, p - i);
auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
- TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block));
+ TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
// Compute the I-WY block representation of a product of Householder
// matrices.
- TF_ASSIGN_OR_RETURN(auto w,
- ComputeWYRepresentation(type, batch_dims, qr_block.vs,
- qr_block.taus, m - i, k));
+ TF_ASSIGN_OR_RETURN(
+ auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+ qr_block.taus, m - i, k, precision));
auto y = qr_block.vs;
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
- auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true);
- a_update = BatchDot(y, a_update);
+ auto a_update =
+ BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ a_update =
+ BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a_panel = a_panel + a_update;
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
- auto q_update = BatchDot(q_panel, w);
- q_update =
- BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true);
+ auto q_update =
+ BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ q_update = BatchDot(q_update, y, /*transpose_x=*/false,
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
q_panel = q_panel + q_update;
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index abd2316ac9..05565477b6 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -32,8 +33,10 @@ struct QRDecompositionResult {
xla::XlaOp r;
};
-xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
- int64 block_size = 128);
+xla::StatusOr<QRDecompositionResult> QRDecomposition(
+ xla::XlaOp a, int64 block_size = 128,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
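A NumPy sketch (not part of the patch) of the two trailing updates in QRDecomposition, assuming W and Y for the current panel were already produced by ComputeWYRepresentation:

import numpy as np

def apply_wy_updates(a_panel, q_panel, w, y):
    # A block of Householder reflectors is represented as I + W @ Y^T,
    # so each update is a pair of GEMMs (the BatchDot calls above).
    a_panel = a_panel + y @ (w.T @ a_panel)   # a[i:, i+k:] update
    q_panel = q_panel + (q_panel @ w) @ y.T   # q[:, i:] update
    return a_panel, q_panel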
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index febb638e5e..37b2240b45 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -110,8 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
});
}
-xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
- bool transpose_a, bool conjugate_a) {
+xla::XlaOp InvertDiagonalBlocks(
+ xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    // Input is a batch of square lower triangular matrices. Its shape is
@@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
- auto update = -DotGeneral(input_row, body_out, dnums);
+ xla::PrecisionConfigProto precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+ auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
body_out = DynamicUpdateSlice(body_out, update, start_indices);
@@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
});
}
-xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
- xla::XlaOp inv_diag_blocks,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a) {
+xla::XlaOp SolveWithInvertedDiagonalBlocks(
+ xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
@@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
auto a_row =
MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
if (left_side) {
- remainder = b_row - BatchDot(a_row, x, transpose_a, false);
+ remainder = b_row - BatchDot(a_row, x, transpose_a, false,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
} else {
- remainder = b_row - BatchDot(x, a_row, false, transpose_a);
+ remainder = b_row - BatchDot(x, a_row, false, transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
}
}
@@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::ConstantR0WithType(builder, xla::S32, j * block_size);
std::vector<xla::XlaOp> update_starts = {start_index, zero};
if (left_side) {
- x_update = BatchDot(inv_block, remainder, transpose_a, false);
+ x_update =
+ BatchDot(inv_block, remainder, transpose_a, false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
} else {
- x_update = BatchDot(remainder, inv_block, false, transpose_a);
+ x_update =
+ BatchDot(remainder, inv_block, false, transpose_a,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
std::swap(update_starts[0], update_starts[1]);
}
x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
@@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size) {
+ int64 block_size,
+ xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
auto diag_blocks = DiagonalBlocks(a, block_size);
// We invert these blocks in parallel using batched matrix-vector products
- auto inv_diag_blocks =
- InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a);
+ auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
+ conjugate_a, precision);
// We now find the solution using GEMMs
- auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
- lower, transpose_a, conjugate_a);
+ auto x =
+ SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
+ transpose_a, conjugate_a, precision);
return x;
});
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 555760b7ef..ac42a48352 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -59,7 +59,9 @@ namespace tensorflow {
// blocking is used.
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size = 128);
+ int64 block_size = 128,
+ xla::PrecisionConfigProto::Precision precision =
+ xla::PrecisionConfigProto::HIGHEST);
} // namespace tensorflow
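A NumPy sketch (not part of the patch) of the strategy in triangular_solve.cc for the left-side, lower-triangular, non-transposed case with floating-point inputs: invert the diagonal blocks up front, then sweep the remaining system with matrix products:

import numpy as np

def blocked_lower_solve_reference(a, b, block_size):
    n = a.shape[0]
    inv_blocks = [np.linalg.inv(a[i:i + block_size, i:i + block_size])
                  for i in range(0, n, block_size)]
    x = np.zeros_like(b)
    for j, i in enumerate(range(0, n, block_size)):
        # Subtract contributions of already-solved blocks, then apply the
        # precomputed inverse of the diagonal block.
        remainder = b[i:i + block_size] - a[i:i + block_size, :i] @ x[:i]
        x[i:i + block_size] = inv_blocks[j] @ remainder
    return x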
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index ace6fd1d8e..4dce0a2102 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -11,6 +11,8 @@ cc_library(
srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index a59c77f5c3..2cd9ae799f 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -13,11 +13,97 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace {
+
+// Helper shape function for operators that return an output with the same rank
+// as their first input.
+Status UnchangedRank(shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("XlaBroadcastHelper")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("broadcast_dims: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("lhs_output: T")
+ .Output("rhs_output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Helper operator for performing XLA-style broadcasts
+
+Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
+whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
+for binary operators.
+
+lhs: the LHS input tensor
+rhs: the RHS input tensor
+broadcast_dims: an XLA-style broadcast dimension specification
+lhs_output: the broadcasted LHS tensor
+rhs_output: the broadcasted RHS tensor
+)doc");
+
+REGISTER_OP("XlaConv")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("lhs_dilation: Tindices")
+ .Input("rhs_dilation: Tindices")
+ .Input("feature_group_count: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
+
+REGISTER_OP("XlaDot")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Attr("T: numbertype")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DotGeneral operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfigProto proto.
+)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
@@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
+REGISTER_OP("XlaPad")
+ .Input("input: T")
+ .Input("padding_value: T")
+ .Input("padding_low: Tindices")
+ .Input("padding_high: Tindices")
+ .Input("padding_interior: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Pad operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#pad
+.
+
+input: A `Tensor` of type T.
+padding_value: A scalar `Tensor` of type T.
+padding_low: the padding to apply at the start of each input dimension.
+padding_high: the padding to apply at the end of each input dimension.
+padding_interior: the padding to apply between each input element.
+output: A `Tensor` of type T.
+)doc");
+
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
@@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
+REGISTER_OP("XlaReduce")
+ .Input("input: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("dimensions_to_reduce: list(int)")
+ .Attr("reducer: func")
+ .Output("output: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ int rank = c->Rank(c->input(0));
+ std::vector<int64> dimensions_to_reduce;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
+ std::set<int64> dims_set(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ auto dim_in_range = [rank](int64 dim) {
+ return dim >= 0 && dim < rank;
+ };
+ if (rank < dimensions_to_reduce.size() ||
+ dims_set.size() != dimensions_to_reduce.size() ||
+ !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
+ return errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce");
+ }
+ c->set_output(
+ 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Wraps the XLA Reduce operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
+
+input: the input tensor
+init_value: a scalar representing the initial value for the reduction
+reducer: a reducer function to apply
+dimensions_to_reduce: dimension numbers over which to reduce
+)doc");
+
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
.Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
.Attr("computation: func")
- .Attr("window_dimensions: list(int)")
- .Attr("window_strides: list(int)")
- .Attr("padding_low: list(int)")
- .Attr("padding_high: list(int)")
.Output("output: T")
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
@@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
-padding_low: the padding to apply at the start of each input dimensions
-padding_high: the padding to apply at the end of each input dimension.
+padding: the padding to apply at the start and end of each input dimension
+)doc");
+
+REGISTER_OP("XlaSelectAndScatter")
+ .Input("operand: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("source: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("select: func")
+ .Attr("scatter: func")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA SelectAndScatter operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
+.
+
+operand: the input tensor
+window_dimensions: the shape of the window
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+source: a tensor of values to scatter
+init_value: a scalar representing the initial value for the output tensor
+select: a selection function to apply
+scatter: a scatter function to apply
)doc");
REGISTER_OP("XlaSend")
@@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD
index 42b6292f79..69ca394360 100644
--- a/tensorflow/compiler/tf2xla/python/BUILD
+++ b/tensorflow/compiler/tf2xla/python/BUILD
@@ -28,5 +28,6 @@ py_library(
srcs = ["xla.py"],
deps = [
"//tensorflow/compiler/tf2xla/ops:gen_xla_ops",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
],
)
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 2fc47dffb8..3626de375e 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -15,11 +15,12 @@
"""Experimental library that exposes XLA operations directly in TensorFlow.
It is sometimes useful to be able to build HLO programs directly from
-TensorFlow. This file provides Tensorflow operators that map as closely as
-possible to HLO operators.
+TensorFlow. This file provides TensorFlow operators that mirror the semantics
+of HLO operators as closely as possible.
-There is no promise of backward or forward compatibility for operators defined
-in this module.
+Note: There is no promise of backward or forward compatibility for operators
+defined in this module. This is primarily because the underlying HLO operators
+do not promise backward or forward compatibility.
"""
from __future__ import absolute_import
@@ -27,11 +28,298 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
+# ops include:
+# infeed/outfeed (available via tf.contrib.tpu)
+# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
+# conditional
+# gather/scatter
+# collapse
+
+# This file reuses builtin names (following XLA's names, so we can call things
+# like xla.max), so we capture the builtin versions here.
+# pylint: disable=redefined-builtin
+_max = max
+_min = min
+_slice = slice # pylint: disable=invalid-name
+
+constant = constant_op.constant
+
+# Unary operators.
+
+# For most arithmetic operators there is a TensorFlow operator
+# that exactly corresponds to each XLA operator. Rather than defining
+# XLA-specific variants, we reuse the corresponding TensorFlow operator.
+# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
+# wrap every HLO operator, because that would allow us to be confident that the
+# semantics match.
+
+
+def _unary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def unary_op_wrapper(x, name=None):
+ return fn(x, name=name)
+
+ return unary_op_wrapper
+
+
+abs = _unary_op(math_ops.abs)
+# TODO(phawkins): implement clz.
+conj = _unary_op(math_ops.conj)
+cos = _unary_op(math_ops.cos)
+ceil = _unary_op(math_ops.ceil)
+digamma = _unary_op(math_ops.digamma)
+erf = _unary_op(math_ops.erf)
+erfc = _unary_op(math_ops.erfc)
+# TODO(phawkins): implement erfinv
+exp = _unary_op(math_ops.exp)
+expm1 = _unary_op(math_ops.expm1)
+floor = _unary_op(math_ops.floor)
+imag = _unary_op(math_ops.imag)
+is_finite = _unary_op(math_ops.is_finite)
+lgamma = _unary_op(math_ops.lgamma)
+log = _unary_op(math_ops.log)
+log1p = _unary_op(math_ops.log1p)
+logical_not = _unary_op(math_ops.logical_not)
+neg = _unary_op(math_ops.neg)
+real = _unary_op(math_ops.real)
+# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
+# numbers halfway between two integers.
+round = _unary_op(math_ops.round)
+sin = _unary_op(math_ops.sin)
+sign = _unary_op(math_ops.sign)
+tanh = _unary_op(math_ops.tanh)
+
+# Binary operators
+
+# The main difference between TensorFlow and XLA binary ops is the broadcasting
+# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
+# requires an explicit specification of which dimensions to broadcast if the
+# arguments have different ranks.
+
+
+def _broadcasting_binary_op(fn):
+ """Wraps a binary Tensorflow operator and performs XLA-style broadcasting."""
+
+ def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
+ """Inner wrapper function."""
+ broadcast_dims = broadcast_dims or []
+ broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
+ # Rather than relying on having static shape information in the TensorFlow
+ # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
+ # at JIT compilation time.
+ x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
+ return fn(x, y, name=name)
+
+ return broadcasting_binary_op_wrapper
+
+
+# Map from TF signed types to TF unsigned types.
+_SIGNED_TO_UNSIGNED_TABLE = {
+ dtypes.int8: dtypes.uint8,
+ dtypes.int16: dtypes.uint16,
+ dtypes.int32: dtypes.uint32,
+ dtypes.int64: dtypes.uint64,
+}
+
+# Map from TF unsigned types to TF signed types.
+_UNSIGNED_TO_SIGNED_TABLE = {
+ dtypes.uint8: dtypes.int8,
+ dtypes.uint16: dtypes.int16,
+ dtypes.uint32: dtypes.int32,
+ dtypes.uint64: dtypes.int64,
+}
+
+
+def _shift_right_logical_helper(x, y, name=None):
+ """Performs an integer right logical shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
+ if signed:
+ unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
+ x = math_ops.cast(x, unsigned_dtype)
+ y = math_ops.cast(y, unsigned_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if signed:
+ output = math_ops.cast(output, dtype)
+ return output
+
+
+def _shift_right_arithmetic_helper(x, y, name=None):
+ """Performs an integer right arithmetic shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
+ if unsigned:
+ signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
+ x = math_ops.cast(x, signed_dtype)
+ y = math_ops.cast(y, signed_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if unsigned:
+ output = math_ops.cast(output, dtype)
+ return output
+
+
+add = _broadcasting_binary_op(math_ops.add)
+sub = _broadcasting_binary_op(math_ops.sub)
+mul = _broadcasting_binary_op(math_ops.mul)
+div = _broadcasting_binary_op(math_ops.div)
+rem = _broadcasting_binary_op(gen_math_ops.mod)
+max = _broadcasting_binary_op(math_ops.maximum)
+min = _broadcasting_binary_op(math_ops.minimum)
+atan2 = _broadcasting_binary_op(math_ops.atan2)
+complex = _broadcasting_binary_op(math_ops.complex)
+logical_and = _broadcasting_binary_op(math_ops.logical_and)
+logical_or = _broadcasting_binary_op(math_ops.logical_or)
+logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
+eq = _broadcasting_binary_op(math_ops.equal)
+ne = _broadcasting_binary_op(math_ops.not_equal)
+ge = _broadcasting_binary_op(math_ops.greater_equal)
+gt = _broadcasting_binary_op(math_ops.greater)
+le = _broadcasting_binary_op(math_ops.less_equal)
+lt = _broadcasting_binary_op(math_ops.less)
+pow = _broadcasting_binary_op(math_ops.pow)
+shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
+shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
+shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
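For instance (an illustrative sketch, not part of the patch), broadcast_dims replaces NumPy-style rank promotion with an explicit dimension mapping, and the shift wrappers make the logical/arithmetic distinction independent of the input's signedness:

x = constant([[1., 2., 3.], [4., 5., 6.]])   # shape [2, 3]
y = constant([10., 20.])                     # shape [2]
z = add(x, y, broadcast_dims=[0])            # y maps onto dimension 0 of x

# For int8 inputs, -128 >> 1 is -64 arithmetically but 64 logically.
s = constant([-128], dtype=dtypes.int8)
one = constant([1], dtype=dtypes.int8)
sra = shift_right_arithmetic(s, one)         # [-64]
srl = shift_right_logical(s, one)            # [64]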
+
+
+def _binary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def binary_op_wrapper(x, y, name=None):
+ return fn(x, y, name=name)
+
+ return binary_op_wrapper
+
+
+transpose = _binary_op(array_ops.transpose)
+rev = _binary_op(array_ops.reverse)
+
+bitcast_convert_type = array_ops.bitcast
+
+
+def broadcast(x, dims, name=None):
+ x = ops.convert_to_tensor(x)
+ shape = array_ops.concat(
+ [constant_op.constant(dims),
+ array_ops.shape(x)], axis=0)
+ return array_ops.broadcast_to(x, shape, name=name)
+
+
+def clamp(a, x, b, name=None):
+ return min(max(a, x, name=name), b, name=name)
+
+
+concatenate = array_ops.concat
+
+
+def conv(lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers,
+ feature_group_count=1,
+ precision_config=None,
+ name=None):
+ """Wraps the XLA ConvGeneralDilated operator.
+
+ ConvGeneralDilated is the most general form of XLA convolution and is
+ documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+
+ Args:
+ lhs: the input tensor
+ rhs: the kernel tensor
+ window_strides: the inter-window strides
+ padding: the padding to apply at the start and end of each input dimensions
+ lhs_dilation: dilation to apply between input elements
+ rhs_dilation: dilation to apply between kernel elements
+ dimension_numbers: a `ConvolutionDimensionNumbers` proto.
+ feature_group_count: number of feature groups for grouped convolution.
+ precision_config: a `PrecisionConfigProto` proto.
+ name: an optional name for the operator
+
+ Returns:
+ A tensor representing the output of the convolution.
+ """
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_conv(
+ lhs,
+ rhs,
+ window_strides=window_strides,
+ padding=padding,
+ lhs_dilation=lhs_dilation,
+ rhs_dilation=rhs_dilation,
+ feature_group_count=feature_group_count,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
+
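A hypothetical call sketch (not part of the patch). It assumes the XLA protos are importable as xla_data_pb2 via the new xla_data_proto_py dependency, and uses an NHWC input / HWIO kernel layout:

from tensorflow.compiler.xla import xla_data_pb2

dnums = xla_data_pb2.ConvolutionDimensionNumbers()
dnums.input_batch_dimension = 0
dnums.input_feature_dimension = 3
dnums.input_spatial_dimensions.extend([1, 2])
dnums.kernel_input_feature_dimension = 2
dnums.kernel_output_feature_dimension = 3
dnums.kernel_spatial_dimensions.extend([0, 1])
dnums.output_batch_dimension = 0
dnums.output_feature_dimension = 3
dnums.output_spatial_dimensions.extend([1, 2])

# lhs: NHWC image batch, rhs: HWIO kernel (tensors defined elsewhere).
out = conv(lhs, rhs,
           window_strides=[1, 1],
           padding=[[0, 0], [0, 0]],
           lhs_dilation=[1, 1],
           rhs_dilation=[1, 1],
           dimension_numbers=dnums)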
+convert_element_type = math_ops.cast
+
+
+def dot(lhs, rhs, name=None):
+ return math_ops.tensordot(lhs, rhs, axes=1, name=name)
+
+
+def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_dot(
+ lhs,
+ rhs,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
+
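For instance (a sketch, not part of the patch; xla_data_pb2 import assumed as in the conv example), a plain matrix product run at the highest operand precision:

dnums = xla_data_pb2.DotDimensionNumbers()
dnums.lhs_contracting_dimensions.append(1)
dnums.rhs_contracting_dimensions.append(0)

precision = xla_data_pb2.PrecisionConfigProto()
precision.operand_precision.extend(
    [xla_data_pb2.PrecisionConfigProto.HIGHEST] * 2)  # one per operand

out = dot_general(lhs, rhs, dnums, precision_config=precision)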
+def dynamic_slice(x, starts, sizes, name=None):
+ # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
+ # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
+ # slice if the slice is out of bounds.
+ return array_ops.slice(x, starts, sizes, name=name)
-# TODO(phawkins): provide wrappers for all XLA operators.
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
+# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
+# the XLA-specific pad operator.
+pad = gen_xla_ops.xla_pad
+
+
+def random_normal(mu, sigma, dims, name=None):
+ mu = ops.convert_to_tensor(mu)
+ return random_ops.random_normal(
+ dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)
+
+
+def random_uniform(minval, maxval, dims, name=None):
+ minval = ops.convert_to_tensor(minval)
+ return random_ops.random_uniform(
+ dims, minval, maxval, dtype=minval.dtype, name=name)
+
+
+recv = gen_xla_ops.xla_recv
+reduce = gen_xla_ops.xla_reduce
+
def reduce_window(operand,
init,
@@ -61,22 +349,38 @@ def reduce_window(operand,
"""
window_strides = window_strides or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
- padding_low = [x for (x, _) in padding]
- padding_high = [y for (_, y) in padding]
return gen_xla_ops.xla_reduce_window(
- operand,
- init,
- reducer,
- window_dimensions,
- window_strides,
- padding_low,
- padding_high,
+ input=operand,
+ init_value=init,
+ window_dimensions=window_dimensions,
+ window_strides=window_strides,
+ padding=padding,
+ computation=reducer,
name=name)
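A NumPy model (not part of the patch) of the windowed reduction in one dimension, which also shows the new padding layout of (low, high) pairs; _max is the captured builtin from the top of this file:

import numpy as np

def reduce_window_reference(operand, init, reducer, window_dimensions,
                            window_strides, padding):
    (lo, hi), = padding                  # one (low, high) pair per dimension
    (w,), (s,) = window_dimensions, window_strides
    x = np.concatenate([np.full(lo, init), operand, np.full(hi, init)])
    out = []
    for start in range(0, x.size - w + 1, s):
        acc = init
        for v in x[start:start + w]:
            acc = reducer(acc, v)
        out.append(acc)
    return np.array(out)

# 1-D max pooling with window 3, stride 2, one element of edge padding:
# reduce_window_reference(a, -np.inf, _max, [3], [2], [(1, 1)])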
-recv = gen_xla_ops.xla_recv
+def reshape(x, new_sizes, dimensions=None, name=None):
+ if dimensions is not None:
+ x = array_ops.transpose(x, dimensions)
+ x = array_ops.reshape(x, new_sizes, name=name)
+ return x
+
+
+def select(condition, x, y, name=None):
+ return array_ops.where(condition, x, y, name)
+
+
+select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
-sort = gen_xla_ops.xla_sort
+def slice(x, start_dims, limit_dims, strides):
+ spec = [
+ _slice(start, limit, stride)
+ for (start, limit, stride) in zip(start_dims, limit_dims, strides)
+ ]
+ return x[tuple(spec)]
+
+
+sort = gen_xla_ops.xla_sort
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
new file mode 100644
index 0000000000..32ba6df2e6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString(
+ XlaResourceOpKind op_kind) {
+ switch (op_kind) {
+ case XlaResourceOpKind::kRead:
+ return "Read";
+ case XlaResourceOpKind::kWrite:
+ return "Write";
+ case XlaResourceOpKind::kReadWrite:
+ return "Modify";
+ }
+}
+
+static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
+ gtl::FlatMap<StringPiece, XlaResourceOpInfo>* result =
+ new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
+
+ auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
+ XlaResourceKind resource_kind) {
+ auto insert_result =
+ result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
+ CHECK(insert_result.second);
+ };
+
+ auto kRead = XlaResourceOpKind::kRead;
+ auto kWrite = XlaResourceOpKind::kWrite;
+ auto kReadWrite = XlaResourceOpKind::kReadWrite;
+
+ auto kVariable = XlaResourceKind::kVariable;
+ auto kStack = XlaResourceKind::kStack;
+ auto kTensorArray = XlaResourceKind::kTensorArray;
+
+ // clang-format off
+ add("AssignAddVariableOp" , kReadWrite, kVariable);
+ add("AssignSubVariableOp" , kReadWrite, kVariable);
+ add("AssignVariableOp" , kWrite, kVariable);
+ add("ReadVariableOp" , kRead, kVariable);
+ add("ResourceApplyAdaMax" , kReadWrite, kVariable);
+ add("ResourceApplyAdadelta" , kReadWrite, kVariable);
+ add("ResourceApplyAdagrad" , kReadWrite, kVariable);
+ add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
+ add("ResourceApplyAdam" , kReadWrite, kVariable);
+ add("ResourceApplyAddSign" , kReadWrite, kVariable);
+ add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
+ add("ResourceApplyFtrl" , kReadWrite, kVariable);
+ add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
+ add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
+ add("ResourceApplyMomentum" , kReadWrite, kVariable);
+ add("ResourceApplyPowerSign" , kReadWrite, kVariable);
+ add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
+ add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
+ add("ResourceApplyRMSProp" , kReadWrite, kVariable);
+ add("ResourceGather" , kRead, kVariable);
+ add("ResourceScatterAdd" , kReadWrite, kVariable);
+ add("ResourceScatterDiv" , kReadWrite, kVariable);
+ add("ResourceScatterMax" , kReadWrite, kVariable);
+ add("ResourceScatterMin" , kReadWrite, kVariable);
+ add("ResourceScatterMul" , kReadWrite, kVariable);
+ add("ResourceScatterNdAdd" , kReadWrite, kVariable);
+ add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
+ add("ResourceScatterSub" , kReadWrite, kVariable);
+ add("ResourceScatterUpdate" , kReadWrite, kVariable);
+ add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
+ add("VarIsInitializedOp" , kRead, kVariable);
+ add("VariableShape" , kRead, kVariable);
+
+ add("StackV2" , kWrite, kStack);
+ add("StackCloseV2" , kRead, kStack);
+ add("StackPopV2" , kReadWrite, kStack);
+ add("StackPushV2" , kReadWrite, kStack);
+
+ add("TensorArrayV3" , kWrite, kTensorArray);
+ add("TensorArrayConcatV3" , kRead, kTensorArray);
+ add("TensorArrayGatherV3" , kRead, kTensorArray);
+ add("TensorArrayScatterV3" , kWrite, kTensorArray);
+ add("TensorArrayGradV3" , kRead, kTensorArray);
+ add("TensorArrayCloseV3" , kRead, kTensorArray);
+ add("TensorArrayReadV3" , kRead, kTensorArray);
+ add("TensorArraySizeV3" , kRead, kTensorArray);
+ add("TensorArraySplitV3" , kWrite, kTensorArray);
+ add("TensorArrayWriteV3" , kWrite, kTensorArray);
+ // clang-format on
+
+ return result;
+}
+
+static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
+GetStaticResourceOpInfoMap() {
+ static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
+ CreateResourceOpInfoMap();
+ return *op_info_map;
+}
+
+const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) {
+ const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
+ GetStaticResourceOpInfoMap();
+ auto it = op_infos.find(op);
+ return it == op_infos.end() ? nullptr : &it->second;
+}
+
+namespace resource_op_table_internal {
+std::vector<StringPiece> GetKnownResourceOps() {
+ std::vector<StringPiece> result;
+ for (const auto& p : GetStaticResourceOpInfoMap()) {
+ result.push_back(p.first);
+ }
+ absl::c_sort(result);
+ return result;
+}
+} // namespace resource_op_table_internal
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h
new file mode 100644
index 0000000000..7f627a64c6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+
+// Exposes information about the resource operations supported by tf2xla in a
+// structured form.
+
+namespace tensorflow {
+enum class XlaResourceOpKind {
+ kRead, // Only reads from resources.
+ kWrite, // Only writes to resources.
+ kReadWrite // Reads from and writes to resources.
+};
+
+enum class XlaResourceKind {
+ kVariable, // Operates on resource variables.
+ kStack, // Operates on stacks.
+ kTensorArray // Operates on tensor arrays.
+};
+
+class XlaResourceOpInfo {
+ public:
+ explicit XlaResourceOpInfo(XlaResourceOpKind op_kind,
+ XlaResourceKind resource_kind)
+ : op_kind_(op_kind), resource_kind_(resource_kind) {}
+
+ XlaResourceOpKind kind() const { return op_kind_; }
+ XlaResourceKind resource_kind() const { return resource_kind_; }
+
+ static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind);
+
+ private:
+ XlaResourceOpKind op_kind_;
+ XlaResourceKind resource_kind_;
+};
+
+// Returns an XlaResourceOpInfo describing `op` if it is a resource operation
+// supported by tf2xla, otherwise returns null (i.e. if this returns null then
+// `op` is either not a resource operation or is unsupported by XLA).
+const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op);
+
+namespace resource_op_table_internal {
+// NB! Implementation detail exposed for unit testing, do not use.
+//
+// Returns the set of resource operations known by this module.
+std::vector<StringPiece> GetKnownResourceOps();
+} // namespace resource_op_table_internal
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
new file mode 100644
index 0000000000..0343f80de9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
+ return arg_def.type() == DT_RESOURCE;
+}
+
+bool HasResourceInputOrOutput(const OpDef& op_def) {
+ return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) ||
+ absl::c_any_of(op_def.output_arg(), IsResourceArgDef);
+}
+
+TEST(ResourceOperationTableTest, HaveAllResourceOps) {
+ gtl::FlatMap<string, bool> known_resource_ops;
+ for (StringPiece known_resource_op :
+ resource_op_table_internal::GetKnownResourceOps()) {
+ ASSERT_TRUE(
+ known_resource_ops.insert({string(known_resource_op), false}).second);
+ }
+
+ std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps();
+ for (const string& xla_op_name : xla_op_names) {
+ const OpDef* op_def;
+ TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def));
+ if (HasResourceInputOrOutput(*op_def)) {
+ EXPECT_EQ(known_resource_ops.count(xla_op_name), 1)
+ << "Unknown resource op " << xla_op_name;
+ known_resource_ops[xla_op_name] = true;
+ }
+ }
+
+ std::vector<string> unnecessary_resource_ops;
+ for (const auto& pair : known_resource_ops) {
+ if (!pair.second) {
+ unnecessary_resource_ops.push_back(pair.first);
+ }
+ }
+
+ EXPECT_TRUE(unnecessary_resource_ops.empty())
+ << "Stale resource ops:\n"
+ << absl::StrJoin(unnecessary_resource_ops, "\n");
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 66835e69b2..2d7eb8b915 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -65,8 +65,8 @@ xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
if (explicit_sharding.has_value()) {
return explicit_sharding;
} else if (!parsed_device.has_type || !parsed_device.has_id ||
- !str_util::StrContains(parsed_device.type,
- kDeviceSuffixReplicatedCore)) {
+ !absl::StrContains(parsed_device.type,
+ kDeviceSuffixReplicatedCore)) {
return absl::optional<xla::OpSharding>();
} else {
const int core = parsed_device.id;
diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc
deleted file mode 100644
index 2b0834fe7b..0000000000
--- a/tensorflow/compiler/tf2xla/str_util.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/str_util.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-namespace tensorflow {
-namespace str_util {
-
-static void ReplaceAll(string* text, StringPiece from, StringPiece to) {
- size_t pos = 0;
- while ((pos = text->find(from.data(), pos, from.size())) != string::npos) {
- text->replace(pos, from.size(), to.data(), to.size());
- pos += to.size();
- if (from.empty()) {
- pos++; // Match at the beginning of the text and after every byte
- }
- }
-}
-
-void ReplaceAllPairs(string* text,
- const std::vector<std::pair<string, string>>& replace) {
- for (const std::pair<string, string>& from_to : replace) {
- ReplaceAll(text, from_to.first, from_to.second);
- }
-}
-
-} // namespace str_util
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h
deleted file mode 100644
index 51f25009d7..0000000000
--- a/tensorflow/compiler/tf2xla/str_util.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// String utilities that are esoteric enough that they don't belong in
-// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally
-// useful under xla.
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
-#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-namespace tensorflow {
-namespace str_util {
-
-// Replace all non-overlapping occurrences of the given (from,to) pairs in-place
-// in text. If from is empty, it matches at the beginning of the text and after
-// every byte. Each (from,to) replacement pair is processed in the order it is
-// given.
-void ReplaceAllPairs(string* text,
- const std::vector<std::pair<string, string>>& replace);
-
-} // namespace str_util
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc
deleted file mode 100644
index 8817f6902a..0000000000
--- a/tensorflow/compiler/tf2xla/str_util_test.cc
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/str_util.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace str_util {
-
-class ReplaceAllPairsTest : public ::testing::Test {
- protected:
- void ExpectReplaceAllPairs(
- string text, const std::vector<std::pair<string, string>>& replace,
- StringPiece want) {
- ReplaceAllPairs(&text, replace);
- EXPECT_EQ(text, want);
- }
-};
-
-TEST_F(ReplaceAllPairsTest, Simple) {
- ExpectReplaceAllPairs("", {}, "");
- ExpectReplaceAllPairs("", {{"", ""}}, "");
- ExpectReplaceAllPairs("", {{"", "X"}}, "X");
- ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ");
- ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_");
- ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_");
- ExpectReplaceAllPairs("banana", {}, "banana");
- ExpectReplaceAllPairs("banana", {{"", ""}}, "banana");
- ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_");
- ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__");
- ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana");
- ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn");
- ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX");
- ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX");
- ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX");
- ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}",
- {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}},
- "a0b123456789c0");
-}
-
-} // namespace str_util
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 48568c825b..f34af2d67d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -197,8 +197,8 @@ Status RewriteAndPruneGraph(
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted(
"Post graph-pruning",
- ", missing feeds: ", str_util::Join(missing_feeds, ", "),
- ", missing fetches: ", str_util::Join(missing_fetches, ", "));
+ ", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
+ ", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
index 7aca889a26..567d212b5e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) {
}
std::sort(types.begin(), types.end());
constraints.push_back("`" + constraint.name() + "={" +
- str_util::Join(types, ",") + "}`");
+ absl::StrJoin(types, ",") + "}`");
}
std::cout << "`" << kdef->op() << "` | "
- << str_util::Join(constraints, "<br>") << std::endl;
+ << absl::StrJoin(constraints, "<br>") << std::endl;
}
std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
@@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
{"device", &device,
"Name of the compilation device for which to print supported ops, "
"one of: " +
- str_util::Join(device_names, ",")},
+ absl::StrJoin(device_names, ",")},
};
string usage = Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index ae51446204..2b1f724dc7 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
@@ -25,16 +26,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 7227df9649..6e5a0198f6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
@@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
@@ -309,10 +309,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
std::move(graph), args, &result);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "depends on a parameter"))
+ absl::StrContains(status.error_message(), "depends on a parameter"))
<< status.error_message();
EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
+ absl::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
<< status.error_message();
}
@@ -727,8 +727,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
/*args=*/{}, &result);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "is not defined."))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
<< status.error_message();
}
@@ -807,12 +806,10 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
ASSERT_FALSE(status.ok());
// Flib lookup failure.
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "is not defined."))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
<< status.error_message();
// Local flib lookup failure.
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "Attr T is not found"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
<< status.error_message();
}
@@ -1078,9 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
<< status.error_message();
}
@@ -1103,10 +1100,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(),
- "is not in the list of allowed values"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(),
+ "is not in the list of allowed values"))
<< status.error_message();
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
<< status.error_message();
}
@@ -1130,9 +1127,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::move(graph_copy), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(
- str_util::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: {{node NoOp}}"))
+ absl::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: {{node NoOp}}"))
<< status.error_message();
}
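[Editor's note.] The pattern applied throughout these tests, isolated (before/after shown as comments; illustrative only):

    // Old call sites sometimes wrapped the argument in StringPiece
    // explicitly:
    //   str_util::StrContains(StringPiece(status.error_message()), "foo");
    // absl::StrContains takes absl::string_view, to which std::string
    // converts implicitly, so the wrapper is simply dropped:
    //   absl::StrContains(status.error_message(), "foo");
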
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 31a41f8719..9e8f5f2a1a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -99,6 +99,25 @@ Status XlaOpKernelContext::ConstantInput(int index,
index, context_->input(index).shape().dim_sizes(), constant_literal);
}
+static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
+ StringPiece name) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was "
+ "expected");
+ }
+ return start;
+}
+
+Status XlaOpKernelContext::ConstantInput(StringPiece name,
+ xla::Literal* constant_literal) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInput(index, constant_literal);
+}
+
Status XlaOpKernelContext::ConstantInputReshaped(
int index, gtl::ArraySlice<int64> new_dims,
xla::Literal* constant_literal) {
@@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
return LiteralToInt64Scalar(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name,
+ int64* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntScalar(index, out);
+}
+
Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
@@ -280,6 +305,12 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name,
+ std::vector<int64>* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntVector(index, out);
+}
+
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
int index, std::vector<int64>* out) {
xla::Literal literal;
@@ -313,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
}
}
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name,
+ xla::Literal* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsInt64Literal(index, out);
+}
+
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3f21a2bf41..3e26ba4f01 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -106,6 +106,7 @@ class XlaOpKernelContext {
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
+ Status ConstantInput(StringPiece name, xla::Literal* constant_literal);
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
@@ -117,12 +118,14 @@ class XlaOpKernelContext {
// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
+ Status ConstantInputAsIntScalar(StringPiece name, int64* out);
// Converts a constant scalar float32 or float64 tensor into a float64.
Status ConstantInputAsFloatScalar(int index, double* out);
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+ Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out);
// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
@@ -130,6 +133,7 @@ class XlaOpKernelContext {
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+ Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out);
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);
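[Editor's sketch, not part of the patch.] How a kernel's Compile method might use the new name-based overloads; the input names "size" and "strides" are hypothetical:

    void Compile(XlaOpKernelContext* ctx) override {
      int64 size;
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("size", &size));
      std::vector<int64> strides;
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("strides", &strides));
      // Each overload resolves the name via InputIndex(), which returns an
      // error for list-valued input names instead of guessing an index.
    }
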
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 46785bc1f0..e25c7e8c9e 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -325,6 +325,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
return kernels;
}
+/*static*/ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() {
+ std::vector<string> ops;
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ for (const auto& pair : registry.ops_) {
+ ops.push_back(pair.first);
+ }
+ std::sort(ops.begin(), ops.end());
+ return ops;
+}
+
/* static */ const std::unordered_set<string>*
XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
XlaOpRegistry& registry = Instance();
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index fc14834ca6..6ce0e2580b 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -128,6 +128,9 @@ class XlaOpRegistry {
const string& compilation_device_name,
bool include_compilation_only_kernels);
+ // Returns all operations for which there are XLA kernels on any device.
+ static std::vector<string> GetAllRegisteredOps();
+
// Returns the set of compile-time constant inputs to 'op'. Returns nullptr
// if the op is not registered.
static const std::unordered_set<string>* CompileTimeConstantInputs(
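[Editor's sketch, not part of the patch.] The new accessor takes the registry mutex and returns a sorted copy, so tooling can enumerate ops without further locking:

    for (const string& op : XlaOpRegistry::GetAllRegisteredOps()) {
      std::cout << op << "\n";  // one op name per line, sorted
    }
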
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 1a8fa627a0..26bd1ac4f7 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -113,6 +113,7 @@ cc_library(
":statusor",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -173,6 +174,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
],
)
@@ -237,11 +239,11 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -259,6 +261,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -301,6 +304,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -320,6 +324,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -341,6 +346,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -359,6 +365,7 @@ cc_library(
":literal_util",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -370,6 +377,7 @@ cc_library(
deps = [
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -379,8 +387,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -391,6 +399,7 @@ cc_library(
":status",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -413,6 +422,7 @@ cc_library(
":types",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -457,6 +467,7 @@ cc_library(
":array2d",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -510,6 +521,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -529,6 +541,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -559,6 +572,7 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -629,6 +643,7 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index 2d5d078aa7..c8e483712e 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -27,12 +27,12 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -507,9 +507,7 @@ class Array {
}
}
- pieces.push_back(
- tensorflow::strings::AlphaNum(values_[calculate_index(index)])
- .data());
+ pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
// Emit comma if it isn't the last element
if (index.back() != sizes_.back() - 1) {
@@ -527,7 +525,7 @@ class Array {
}
}
} while (next_index(&index));
- return tensorflow::str_util::Join(pieces, "");
+ return absl::StrJoin(pieces, "");
}
private:
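[Editor's note.] The formatting change in Array::ToString, isolated (illustrative):

    // Before: tensorflow::strings::AlphaNum(v).data() yielded a const char*
    // into a short-lived AlphaNum, which the string copied out immediately.
    // After: absl::StrCat accepts arithmetic types directly and returns an
    // owned std::string.
    pieces.push_back(absl::StrCat(3.5f));  // pieces gains "3.5"
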
diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h
index 340f94fab7..782c966b4c 100644
--- a/tensorflow/compiler/xla/array2d.h
+++ b/tensorflow/compiler/xla/array2d.h
@@ -25,11 +25,10 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h
index a75fffc605..14e7bf1814 100644
--- a/tensorflow/compiler/xla/array4d.h
+++ b/tensorflow/compiler/xla/array4d.h
@@ -26,12 +26,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index c8b2a1ac73..9ad8ee2014 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -77,6 +77,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -90,6 +91,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -216,6 +218,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 25608d6616..1fdf8f6260 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -400,7 +400,7 @@ StatusOr<string> Client::ExecutionStatsAsString(
int64 nanoseconds = profile.compute_time_ns();
int64 cycle_count = profile.compute_cycle_count();
double gflops = total_flops / nanoseconds;
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"[Execution Statistics] flop count: ", computation_stats.flop_count(),
", transcendental count: ", computation_stats.transcendental_count(),
", compute execution time: ", nanoseconds, " nsec",
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index b6012a0352..040344c9a6 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime(
metadata);
}
-int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
+int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) {
llvm::Triple llvm_triple(
llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size())));
if (llvm_triple.isArch64Bit()) {
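[Editor's note.] Call sites are unchanged by the signature swap, since string literals and std::string convert implicitly to absl::string_view just as they did to tensorflow::StringPiece. An illustrative call:

    int64 bytes =
        CompileOnlyClient::PointerSizeForTriple("x86_64-pc-linux-gnu");
    // bytes == 8 for a 64-bit triple.
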
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
index a551edeab0..d0c83cbfcc 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.h
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -57,7 +57,7 @@ class CompileOnlyClient : public Client {
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
// Returns the size of a pointer in bytes for a given triple.
- static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
+ static int64 PointerSizeForTriple(absl::string_view triple);
private:
CompileOnlyService* compiler_service_;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 5fe28c33df..5a73408db5 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -77,7 +77,7 @@ const absl::optional<string>& ExecutableBuildOptions::generate_hlo_graph()
}
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
+ absl::string_view dirpath) {
dump_optimized_hlo_proto_to_ = string(dirpath);
return *this;
}
@@ -89,8 +89,8 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
ExecutableBuildOptions&
ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
- dump_unoptimized_hlo_proto_to_ = dirpath.ToString();
+ absl::string_view dirpath) {
+ dump_unoptimized_hlo_proto_to_ = string(dirpath);
return *this;
}
@@ -100,7 +100,7 @@ ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const {
}
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
+ absl::string_view dirpath) {
dump_per_pass_hlo_proto_to_ = string(dirpath);
return *this;
}
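[Editor's sketch, not part of the patch; the dump path is hypothetical.] The setters return *this, so they chain, and both string literals and std::string convert implicitly to absl::string_view:

    ExecutableBuildOptions build_options;
    build_options.set_dump_optimized_hlo_proto_to("/tmp/xla_dump")
        .set_dump_unoptimized_hlo_proto_to("/tmp/xla_dump");
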
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 904d230981..888d2f28eb 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -62,19 +62,19 @@ class ExecutableBuildOptions {
// If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_optimized_hlo_proto_to(
- tensorflow::StringPiece dirpath);
+ absl::string_view dirpath);
const absl::optional<string>& dump_optimized_hlo_proto_to() const;
// If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to(
- tensorflow::StringPiece dirpath);
+ absl::string_view dirpath);
const absl::optional<string>& dump_unoptimized_hlo_proto_to() const;
// If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
// to (as in DebugOptions).
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
- tensorflow::StringPiece dirpath);
+ absl::string_view dirpath);
const absl::optional<string>& dump_per_pass_hlo_proto_to() const;
// If true, specifies that we should record an HLO profile during execution
@@ -83,7 +83,7 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_hlo_profile(bool enabled);
absl::optional<bool> hlo_profile() const;
- void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) {
+ void add_disabled_hlo_pass(absl::string_view pass_name) {
disabled_hlo_passes_.push_back(std::string(pass_name));
}
const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 4d233741bd..8736f18dcf 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -31,7 +31,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -221,5 +221,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 9225b1acd6..e86c10f030 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
@@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
b = builder->CreateSubBuilder(name);
} else {
b = builder->CreateSubBuilder(
- tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
+ absl::StrCat(name, "_", PrimitiveType_Name(type)));
}
const Shape scalar = ShapeUtil::MakeShape(type, {});
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 081fec7ad9..6861521acc 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
Client* client) {
- XlaBuilder b(
- tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
+ XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
BuildFakeDataOpOnDevice(shape, &b);
XlaComputation computation = b.Build().ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 4e7ef66dc5..9f902d7298 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -23,6 +23,9 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
@@ -31,12 +34,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
namespace {
@@ -223,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() {
auto build_status = Build();
if (!build_status.ok()) {
parent_builder_->ReportError(
- AddStatus(build_status.status(),
- tensorflow::strings::StrCat("error from: ", name_)));
+ AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
return {};
}
return build_status.ConsumeValueOrDie();
@@ -705,8 +706,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));
VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
- VLOG(3) << "dims to collapse: "
- << tensorflow::str_util::Join(dimensions, ",");
+ VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
std::vector<int64> new_sizes;
for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
@@ -717,8 +717,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
}
- VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
- << "]";
+ VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
return Reshape(operand, new_sizes);
});
@@ -1013,7 +1012,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
return Status::OK();
} else {
return InvalidArgument(
- "%s", tensorflow::strings::StrCat(
+ "%s", absl::StrCat(
"Window has different number of window dimensions than of ",
x_name,
"\nNumber of window dimensions: ", window_dimensions.size(),
@@ -1283,7 +1282,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- if (tensorflow::str_util::StartsWith(call_target_name, "$")) {
+ if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
"are reserved for internal use.",
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 3dbf1e5bee..baa2ae5184 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <type_traits>
#include <utility>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h
index 1a51fdee68..6d51126d88 100644
--- a/tensorflow/compiler/xla/device_util.h
+++ b/tensorflow/compiler/xla/device_util.h
@@ -21,8 +21,8 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -30,8 +30,8 @@ namespace xla {
// Returns a string that represents the device in terms of platform and ordinal;
// e.g. the first CUDA device will be "cuda:0"
string DeviceIdentifier(se::StreamExecutor* stream_exec) {
- return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":",
- stream_exec->device_ordinal());
+ return absl::StrCat(stream_exec->platform()->Name(), ":",
+ stream_exec->device_ordinal());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
index ffd1fb79e9..693dcb3a3e 100644
--- a/tensorflow/compiler/xla/index_util.cc
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -18,10 +18,10 @@ limitations under the License.
#include <algorithm>
#include <string>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -36,7 +36,7 @@ namespace xla {
DCHECK_GE(multi_index[i], 0);
DCHECK_LT(multi_index[i], shape.dimensions(i))
<< "indexing beyond extent in dimension " << i << ":"
- << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",")
+ << "\n\tindex: " << absl::StrJoin(multi_index, ",")
<< "\n\tshape: " << ShapeUtil::HumanString(shape);
}
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index b72d190d54..61c26434b1 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,8 +33,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -211,7 +211,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
"layout minor_to_major field contains %d elements, "
"but shape is rank %lld: {%s}; shape: %s",
layout.minor_to_major_size(), ShapeUtil::Rank(shape),
- tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(),
+ absl::StrJoin(layout.minor_to_major(), ", ").c_str(),
shape.ShortDebugString().c_str());
}
@@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
if (IsSparse(layout)) {
- return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(),
- "}");
+ return absl::StrCat("sparse{", layout.max_sparse_elements(), "}");
}
CHECK(IsDense(layout));
- return tensorflow::strings::StrCat(
- "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}");
+ return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}");
}
namespace {
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index 89353448e2..989035896b 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -56,6 +56,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -73,5 +74,6 @@ tf_cc_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 5d27e4a46b..0d3136b0cc 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace legacy_flags {
@@ -87,7 +87,7 @@ void AllocateFlags() {
// Custom "sub-parser" lambda for xla_disable_hlo_passes.
auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
std::vector<string> disabled_passes =
- tensorflow::str_util::Split(comma_separated_values, ',');
+ absl::StrSplit(comma_separated_values, ',');
for (const auto& passname : disabled_passes) {
flag_values->add_xla_disable_hlo_passes(passname);
}
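[Editor's note.] absl::StrSplit returns a lazy range that converts to the container it is assigned to, which is what makes it a drop-in here (illustrative):

    std::vector<string> disabled_passes = absl::StrSplit("cse,dce", ',');
    // disabled_passes == {"cse", "dce"}
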
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
index e9cf435d83..acda438395 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
@@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -30,7 +31,7 @@ template <typename T>
void parse_xla_backend_extra_options(T* extra_options_map,
string comma_separated_values) {
std::vector<string> extra_options_parts =
- tensorflow::str_util::Split(comma_separated_values, ',');
+ absl::StrSplit(comma_separated_values, ',');
// The flag contains a comma-separated list of options; some options
// have arguments following "=", some don't.
@@ -59,8 +60,7 @@ void parse_xla_backend_extra_options(T* extra_options_map,
inline bool parse_xla_reduce_precision_option(
HloReducePrecisionOptions* options, string option_string) {
// Split off "LOCATION" from remainder of string.
- std::vector<string> eq_split =
- tensorflow::str_util::Split(option_string, '=');
+ std::vector<string> eq_split = absl::StrSplit(option_string, '=');
if (eq_split.size() != 2) {
return false;
}
@@ -80,26 +80,25 @@ inline bool parse_xla_reduce_precision_option(
}
// Split off "E,M" from remainder of string.
- std::vector<string> colon_split =
- tensorflow::str_util::Split(eq_split[1], ':');
+ std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':');
if (colon_split.size() != 2) {
return false;
}
// Split E and M, and parse.
std::vector<int32> bitsizes;
- if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',',
- &bitsizes) ||
- bitsizes.size() != 2) {
- return false;
+ for (const auto& s : absl::StrSplit(colon_split[0], ',')) {
+ bitsizes.emplace_back();
+ if (!absl::SimpleAtoi(s, &bitsizes.back())) {
+ return false;
+ }
   }
+  if (bitsizes.size() != 2) {
+    return false;
+  }
options->set_exponent_bits(bitsizes[0]);
options->set_mantissa_bits(bitsizes[1]);
// Split off OPS comma-separated list from remainder of string, if the
// remainder exists.
- std::vector<string> semicolon_split =
- tensorflow::str_util::Split(colon_split[1], ';');
+ std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';');
if (semicolon_split.size() > 2) {
return false;
}
@@ -113,8 +112,7 @@ inline bool parse_xla_reduce_precision_option(
options->add_opcodes_to_suffix(i);
}
} else {
- std::vector<string> opcodes =
- tensorflow::str_util::Split(opcode_string, ',');
+ std::vector<string> opcodes = absl::StrSplit(opcode_string, ',');
for (const string& opcode : opcodes) {
bool found = false;
for (int i = 0; i < HloOpcodeCount(); i++) {
@@ -132,8 +130,7 @@ inline bool parse_xla_reduce_precision_option(
// Process the NAMES string, if it exists.
if (semicolon_split.size() == 2) {
- std::vector<string> opnames =
- tensorflow::str_util::Split(semicolon_split[1], ',');
+ std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ',');
for (const string& opname : opnames) {
if (opname.length() > 0) {
options->add_opname_substrings_to_suffix(opname);
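[Editor's sketch, not part of the patch.] The deleted SplitAndParseAsInts bundled splitting, integer parsing, and an element-count check into one call; the absl replacement spells those steps out. A standalone sketch of the full pattern, including the size check restored above:

    std::vector<int32> bitsizes;
    for (absl::string_view piece : absl::StrSplit("8,23", ',')) {
      int32 value;
      if (!absl::SimpleAtoi(piece, &value)) return false;  // non-numeric
      bitsizes.push_back(value);
    }
    if (bitsizes.size() != 2) return false;  // exactly "E,M" expected
    // bitsizes == {8, 23}
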
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
index 0ed788a967..6f197aec53 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index d54f051a1a..30b890737b 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,19 +33,16 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+using tensorflow::strings::Printf;
+
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Converts between little and big endian.
@@ -1030,9 +1029,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
element_index.push_back(i);
std::vector<string> element_pieces;
ToStringHelper(literal, element_index, print_layout, &element_pieces);
- tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+ tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
}
- pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+ pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
pieces->push_back("\n)");
return;
}
@@ -1056,8 +1055,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(": ");
} else {
pieces->push_back("[");
- pieces->push_back(
- tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
+ pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
pieces->push_back("]: ");
}
pieces->push_back(literal.GetSparseElementAsString(i));
@@ -1183,7 +1181,7 @@ string LiteralBase::ToString(bool print_layout) const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
- return tensorflow::str_util::Join(pieces, "");
+ return absl::StrJoin(pieces, "");
}
void LiteralBase::EachCellAsString(
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index ed9de65299..aad435ed5b 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 6883a6bbab..67a69c2403 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -19,16 +19,16 @@ limitations under the License.
#include <cmath>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+using absl::StrAppend;
+using absl::StrCat;
using tensorflow::strings::Appendf;
using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
namespace xla {
namespace literal_comparison {
@@ -47,9 +47,9 @@ Status CompareFloatsBitwiseEqual(
if (ulhs != urhs) {
return InvalidArgument(
"floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a at index %s",
- StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
- StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double,
+ "was requested: %s=%g=%a vs %s=%g=%a at array index %s",
+ StrCat(absl::Hex(ulhs)).c_str(), lhs_double, lhs_double,
+ StrCat(absl::Hex(urhs)).c_str(), rhs_double, rhs_double,
LiteralUtil::MultiIndexAsString(multi_index).c_str());
}
return Status::OK();
@@ -65,9 +65,10 @@ Status CompareEqual(NativeT lhs, NativeT rhs,
return Status::OK();
}
return InvalidArgument(
- "Expected equality of these values:\n %s\n %s\nat index %s",
- StrCat(lhs).c_str(), StrCat(rhs).c_str(),
- LiteralUtil::MultiIndexAsString(multi_index).c_str());
+ "first mismatch at array index %s:\n expected value: %s\n actual "
+ "value: %s",
+ LiteralUtil::MultiIndexAsString(multi_index).c_str(), StrCat(lhs).c_str(),
+ StrCat(rhs).c_str());
}
// Specializations for floating types that do bitwise comparisons when equality
@@ -119,7 +120,8 @@ Status Equal(LiteralSlice expected, LiteralSlice actual,
Status result;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
- result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+ TF_RETURN_IF_ERROR(
+ Equal<NativeT>(expected, actual, multi_index, dimension + 1));
}
return result;
}
@@ -251,11 +253,6 @@ class NearComparator {
// Runs the comparison between expected and actual literals.
Status Run() {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, ToStringTruncated(expected_));
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, ToStringTruncated(actual_));
-
// If the shapes mismatch, we simply fail the expectation instead of
// printing out data, as it's a type error rather than a value error.
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
@@ -539,6 +536,62 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
template <typename NativeT>
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
+Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ Status result;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ result = Equal<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ result = Equal<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ result = Equal<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ result = Equal<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ result = Equal<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ result = Equal<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case BF16:
+ result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ break;
+ case F16:
+ result = Equal<half>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ result = Equal<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ result = Equal<double>(expected, actual, &multi_index, 0);
+ break;
+ case C64:
+ result = Equal<complex64>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ result.Update(EqualHelper(LiteralSlice(expected, {i}),
+ LiteralSlice(actual, {i})));
+ }
+ break;
+ }
+ case TOKEN:
+ // Tokens have no on-device representation and are trivially equal.
+ return Status::OK();
+ default:
+ LOG(FATAL) << "Unsupported primitive type: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+
+ return result;
+}
+
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
@@ -555,17 +608,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const auto actual_element = LiteralSlice(actual, {i});
ShapeIndex element_index = shape_index;
element_index.push_back(i);
- Status res =
+ Status element_result =
NearHelper(expected_element, actual_element, error, detailed_message,
miscompare_callback, element_index);
- if (!res.ok()) {
- string err_message = Printf("\nArray at shape index %s%s",
- element_index.ToString().c_str(),
- res.error_message().c_str());
+ if (!element_result.ok()) {
+ element_result = InvalidArgument(
+ "Array at shape index %s, %s", element_index.ToString().c_str(),
+ element_result.error_message().c_str());
if (return_status.ok()) {
- return_status = res;
+ return_status = element_result;
} else {
- return_status = AppendStatus(return_status, res.error_message());
+ return_status =
+ AppendStatus(return_status, element_result.error_message());
}
}
}
@@ -611,8 +665,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
}
}
- // Non-floating point literal.
- return literal_comparison::Equal(expected, actual);
+ // Non-floating point, non-tuple literal.
+ return EqualHelper(expected, actual);
}
} // namespace
@@ -668,81 +722,44 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return Status::OK();
}
+namespace {
+
+// If result is an error, extend the error message with the expected and actual
+// literals.
+Status EmitLiteralsInErrorMessage(const Status& result,
+ const LiteralSlice& expected,
+ const LiteralSlice& actual) {
+ if (result.ok()) {
+ return result;
+ }
+ return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
+ result.error_message().c_str(),
+ ToStringTruncated(expected).c_str(),
+ ToStringTruncated(actual).c_str());
+}
+
+} // namespace
+
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
-
- TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
- std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
- Status result;
- switch (expected.shape().element_type()) {
- case PRED:
- result = Equal<bool>(expected, actual, &multi_index, 0);
- break;
- case U8:
- result = Equal<uint8>(expected, actual, &multi_index, 0);
- break;
- case S32:
- result = Equal<int32>(expected, actual, &multi_index, 0);
- break;
- case S64:
- result = Equal<int64>(expected, actual, &multi_index, 0);
- break;
- case U32:
- result = Equal<uint32>(expected, actual, &multi_index, 0);
- break;
- case U64:
- result = Equal<uint64>(expected, actual, &multi_index, 0);
- break;
- case BF16:
- result = Equal<bfloat16>(expected, actual, &multi_index, 0);
- break;
- case F16:
- result = Equal<half>(expected, actual, &multi_index, 0);
- break;
- case F32:
- result = Equal<float>(expected, actual, &multi_index, 0);
- break;
- case F64:
- result = Equal<double>(expected, actual, &multi_index, 0);
- break;
- case C64:
- result = Equal<complex64>(expected, actual, &multi_index, 0);
- break;
- case TUPLE: {
- for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- result.Update(
- Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
- }
- break;
- }
- case TOKEN:
- // Tokens have no on-device representation and are trivially equal.
- return Status::OK();
- default:
- LOG(FATAL)
- << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
- << PrimitiveType_Name(expected.shape().element_type());
- }
-
- if (result.ok()) {
- return Status::OK();
- }
-
- return AppendStatus(
- result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s",
- ToStringTruncated(expected).c_str(),
- ToStringTruncated(actual).c_str()));
+ Status result = EqualHelper(expected, actual);
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const MiscompareCallback& miscompare_callback) {
- return NearHelper(expected, actual, error, detailed_message,
- miscompare_callback,
- /*shape_index=*/{});
+ VLOG(1) << "Expected literal:";
+ XLA_VLOG_LINES(1, expected.ToString());
+ VLOG(1) << "Actual literal:";
+ XLA_VLOG_LINES(1, actual.ToString());
+ Status result =
+ NearHelper(expected, actual, error, detailed_message, miscompare_callback,
+ /*shape_index=*/{});
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
string ToStringTruncated(const LiteralSlice& literal) {
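[Editor's note.] Net effect of the refactor above: the per-element recursion (EqualHelper / NearHelper) now reports the first mismatch in an array rather than accumulating every one, and the full expected/actual literals are appended exactly once, at the top level, by EmitLiteralsInErrorMessage. The resulting message has this shape (contents illustrative):

    first mismatch at array index {1,0}:
     expected value: 42
     actual value: 43

    Expected literal:
    ...

    Actual literal:
    ...
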
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index c5d0c2c267..aef87e46d8 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -1324,8 +1326,8 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
Status status = literal->BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
- "bit widths are different"));
+ EXPECT_TRUE(
+ absl::StrContains(status.error_message(), "bit widths are different"));
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
@@ -1819,21 +1821,20 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
"false");
ASSERT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(int64{2}));
+ absl::StrCat(int64{2}));
ASSERT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(double{2.0}));
+ absl::StrCat(double{2.0}));
ASSERT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(static_cast<float>(half{2.0})));
- ASSERT_EQ(
- LiteralUtil::CreateSparse<complex64>(
- dimensions, indices,
- std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
+ absl::StrCat(static_cast<float>(half{2.0})));
+ ASSERT_EQ(LiteralUtil::CreateSparse<complex64>(
+ dimensions, indices,
+ std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
+ ->GetSparseElementAsString(1),
+ absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index d4c7b76b28..95d93acfe8 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,19 +33,16 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
@@ -287,7 +286,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
- tensorflow::StringPiece value) {
+ absl::string_view value) {
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
@@ -477,7 +476,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ string LiteralUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
- return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+ return StrCat("{", absl::StrJoin(multi_index, ","), "}");
}
} // namespace xla
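MultiIndexAsString above is a behavioral no-op: absl::StrJoin with a plain separator formats each element with its default formatter, exactly as str_util::Join did. A runnable sketch of the same expression:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    int main() {
      std::vector<int64_t> multi_index = {1, 2, 3};
      // Same shape as MultiIndexAsString: braces around a ","-joined list.
      std::cout << absl::StrCat("{", absl::StrJoin(multi_index, ","), "}")
                << "\n";  // {1,2,3}
    }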
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 1109021ea8..3d28c070f2 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -43,7 +44,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -187,7 +187,7 @@ class LiteralUtil {
const Array4D<NativeT>& values, const Layout& layout);
// Creates a new vector of U8s literal value from a string.
- static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
+ static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
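The header change above narrows CreateR1U8's parameter to absl::string_view, which binds to string literals, std::string, and raw character buffers without copying. A toy illustration (CountBytes is a hypothetical helper, not part of LiteralUtil):

    #include <cstddef>
    #include <iostream>
    #include <string>

    #include "absl/strings/string_view.h"

    // Hypothetical helper: a string_view parameter accepts any contiguous
    // character source with no allocation at the call site.
    std::size_t CountBytes(absl::string_view value) { return value.size(); }

    int main() {
      std::string owned = "abc";
      std::cout << CountBytes(owned) << " " << CountBytes("hello") << "\n";  // 3 5
    }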
diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc
index 69ef4f7a2f..2f22e02c3e 100644
--- a/tensorflow/compiler/xla/metric_table_report.cc
+++ b/tensorflow/compiler/xla/metric_table_report.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <cctype>
#include <unordered_map>
+#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) {
if (end_of_line == string::npos) {
end_of_line = report.size();
}
- tensorflow::StringPiece line(report.data() + pos, end_of_line - pos);
+ absl::string_view line(report.data() + pos, end_of_line - pos);
// TODO(b/34779244): Figure out how to do this without the verbose log-line
// prefix. The usual way didn't compile on open source.
@@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() {
if (text.empty()) {
text = "[no category]";
}
- tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ",
- entry_name_, ")");
+ absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_,
+ ")");
AppendTableRow(text, category.metric_sum, metric_sum);
// Show the top entries in the category.
@@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() {
}
const int64 remaining_categories = categories.size() - categories_shown;
if (remaining_categories > 0) {
- AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories,
- " more categories)"),
- expected_metric_sum_ - metric_sum, expected_metric_sum_);
+ AppendTableRow(
+ absl::StrCat("... (", remaining_categories, " more categories)"),
+ expected_metric_sum_ - metric_sum, expected_metric_sum_);
}
}
@@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() {
}
const int64 remaining_entries = entries_.size() - entries_shown;
if (remaining_entries > 0) {
- AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries,
- " more ", entry_name_, ")"),
- expected_metric_sum_ - metric_sum, expected_metric_sum_);
+ AppendTableRow(
+ absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"),
+ expected_metric_sum_ - metric_sum, expected_metric_sum_);
}
}
@@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() {
string MetricTableReport::MetricString(double metric) {
// Round to integer and stringify.
- string s1 = tensorflow::strings::StrCat(std::llround(metric));
+ string s1 = absl::StrCat(std::llround(metric));
// Code below commafies the string, e.g. "1234" becomes "1,234".
- tensorflow::StringPiece sp1(s1);
+ absl::string_view sp1(s1);
string output;
// Copy leading non-digit characters unconditionally.
// This picks up the leading sign.
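The MetricString hunk keeps the commafying algorithm intact and only swaps the string types. A self-contained sketch of that algorithm, simplified to non-negative input (the real code also copies a leading sign through unchanged):

    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <string>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/string_view.h"

    // Round, stringify, then insert ',' before every group of three digits.
    std::string Commafy(double metric) {
      std::string s = absl::StrCat(std::llround(metric));
      absl::string_view sp(s);
      std::string output;
      for (std::size_t i = 0; i < sp.size(); ++i) {
        if (i > 0 && (sp.size() - i) % 3 == 0) output.push_back(',');
        output.push_back(sp[i]);
      }
      return output;
    }

    int main() { std::cout << Commafy(1234567.0) << "\n"; }  // 1,234,567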
diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h
index 818fb1d3fe..062d8ed99b 100644
--- a/tensorflow/compiler/xla/metric_table_report.h
+++ b/tensorflow/compiler/xla/metric_table_report.h
@@ -18,9 +18,8 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -108,7 +107,7 @@ class MetricTableReport {
// Append all parameters to the report.
template <typename... Args>
void AppendLine(Args... args) {
- tensorflow::strings::StrAppend(&report_, std::forward<Args>(args)..., "\n");
+ absl::StrAppend(&report_, std::forward<Args>(args)..., "\n");
}
// Represents a set of entries with the same category_text.
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 55c4a80e29..012df87551 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -64,7 +64,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
tensorflow::gtl::ArraySlice<float> field = result->data<float>();
char* data = tensorflow::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- tensorflow::StringPiece sp;
+ tensorflow::StringPiece sp; // non-absl OK
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- tensorflow::StringPiece sp;
+ tensorflow::StringPiece sp; // non-absl OK
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
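The two "non-absl OK" comments above mark deliberate holdouts: RandomAccessFile::Read takes a tensorflow::StringPiece* out-parameter, so these locals must keep that type until the file API itself migrates. A generic stand-in (using std::string_view, not the TF API) showing the fill-scratch-then-point-the-view pattern:

    #include <cstddef>
    #include <cstring>
    #include <iostream>
    #include <string_view>

    // Stand-in for the Read() out-parameter idiom: the callee fills the
    // caller's scratch buffer and points the view at it, so the view's type
    // must match the API signature exactly.
    void Read(std::size_t n, std::string_view* result, char* scratch) {
      std::memset(scratch, 'x', n);
      *result = std::string_view(scratch, n);
    }

    int main() {
      char scratch[4];
      std::string_view sp;
      Read(sizeof(scratch), &sp, scratch);
      std::cout << sp.size() << "\n";  // 4
    }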
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index a91336c3ac..2d8fe434b0 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -39,6 +39,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index e1060d54e2..08dccb3ee1 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,6 +109,7 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -896,7 +897,7 @@ tensorflow::ImportNumpy();
if (o != Py_None) {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
if (!statusor.ok()) {
- PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
+ PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
Py_DECREF(o);
SWIG_fail;
}
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 4b9970eadc..f2f99c1745 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -191,8 +192,8 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
PyObject* result =
PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
if (result == nullptr) {
- return error(tensorflow::strings::StrCat(
- "Failed to call method of shape object:", method));
+ return error(
+ absl::StrCat("Failed to call method of shape object:", method));
}
return result;
};
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 32723849a6..aa826aa770 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -99,6 +99,7 @@ cc_library(
":bfloat16_support",
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
@@ -176,6 +177,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -241,6 +243,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -320,6 +323,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -346,7 +350,7 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -398,7 +402,7 @@ cc_library(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -460,6 +464,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -564,6 +569,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -587,6 +593,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -629,6 +636,7 @@ cc_library(
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = 1,
)
@@ -662,6 +670,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -684,6 +693,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -735,6 +745,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -832,6 +843,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -851,6 +863,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -886,6 +899,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -896,6 +910,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -930,6 +945,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -939,6 +955,7 @@ tf_cc_test(
deps = [
":buffer_liveness",
":hlo",
+ ":hlo_dataflow_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -974,6 +991,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1021,6 +1039,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1113,6 +1132,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1199,6 +1219,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1215,6 +1236,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1330,6 +1352,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -1355,6 +1378,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1419,6 +1443,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1457,6 +1482,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1497,6 +1523,7 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -1511,6 +1538,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1701,6 +1729,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = True, # Contains per-platform computation placer registration
)
@@ -1714,6 +1743,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1807,6 +1837,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1839,6 +1870,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1871,6 +1903,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
],
)
@@ -1898,6 +1931,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
],
)
@@ -1916,6 +1950,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1936,6 +1971,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1978,6 +2014,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2014,6 +2051,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2034,6 +2072,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2093,6 +2132,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2144,6 +2184,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2166,6 +2207,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2235,6 +2277,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2278,6 +2321,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
],
)
@@ -2400,6 +2444,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2633,6 +2678,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:transform_utils",
],
@@ -2666,8 +2712,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -2681,6 +2727,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2717,8 +2764,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -2752,6 +2799,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
alwayslink = 1,
@@ -2769,6 +2817,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2997,8 +3046,8 @@ cc_library(
":hlo_creation_utils",
":tuple_util",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -3114,6 +3163,7 @@ cc_library(
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -3150,6 +3200,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -3164,6 +3215,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main", # fixdeps: keep
+ "@com_google_absl//absl/strings",
],
)
@@ -3182,6 +3234,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index b86b7d2e71..c236453fc7 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -1989,9 +1990,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
- << (convert != nullptr ? tensorflow::strings::StrCat(
- "\nvia convert: ", convert->ToString())
- : "");
+ << (convert != nullptr
+ ? absl::StrCat("\nvia convert: ", convert->ToString())
+ : "");
// Do not fold interior padding into ReduceWindow since the backends do not
// support it.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index c48196e861..b864c372fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface {
enable_dot_strength_reduction_(enable_dot_strength_reduction),
enable_conv_simplification_(enable_conv_simplification) {}
~AlgebraicSimplifier() override = default;
- tensorflow::StringPiece name() const override { return "algsimp"; }
+ absl::string_view name() const override { return "algsimp"; }
// Run algebraic simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 427069af5f..bb63ea26d4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,13 +36,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-using ::testing::ElementsAre;
namespace xla {
namespace {
+using ::testing::ElementsAre;
+
namespace op = xla::testing::opcode_matchers;
AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
@@ -51,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {
+ public:
+ AlgebraicSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -2143,9 +2149,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
root->operand(0)->opcode() == HloOpcode::kDot) {
auto lhs_shape = root->operand(0)->operand(0)->shape();
auto rhs_shape = root->operand(0)->operand(1)->shape();
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
- tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
+ return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
+ absl::StrJoin(rhs_shape.dimensions(), "x"));
}
return "UNEXPECTED CHANGE";
};
@@ -2660,11 +2665,10 @@ struct PadReduceWindowEffectiveBroadcastCase {
bool should_become_broadcast;
string ToTestCaseName() const {
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(input_spatials, ","), ";",
- tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
- tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
- ";", should_become_broadcast);
+ return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
+ absl::StrJoin(symmetric_pad_spatials, ","), ";",
+ absl::StrJoin(reduce_window_spatials, ","), ";",
+ prepend_a, ";", should_become_broadcast);
}
};
@@ -2852,7 +2856,12 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {
+ public:
+ DotOfConcatSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3025,7 +3034,12 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {
+ public:
+ DotOfGatherSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index d0806d24a2..5115a14df0 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 1bc3796fa4..4a6a78daf0 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -130,7 +130,7 @@ class Backend {
// Return a string identifier for the given device, eg: "GPU:3".
string device_name(int device_ordinal) const {
- return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal);
+ return absl::StrCat(platform_->Name(), ":", device_ordinal);
}
// Returns true if the devices with the given ordinals are equivalent from
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index be6fbcc9e3..a16b85a0a5 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -78,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
return true;
}
-tensorflow::StringPiece BatchDotSimplification::name() const {
+absl::string_view BatchDotSimplification::name() const {
return "batch-dot-simplification";
}
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index c0ca8d8eba..79d37f08d3 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -28,7 +28,7 @@ namespace xla {
class BatchDotSimplification : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
private:
StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot(
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index 38f1a5d3a6..b342acb025 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,7 +24,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {
+ public:
+ BatchDotSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 7ae202c583..76e32174f3 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface {
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
- tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
+ absl::string_view name() const override { return "batchnorm_expander"; }
// Run operation expander on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index f62ab12319..aba0d9bb5b 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index c939838709..5dcd31b83d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16ConversionFolding() override = default;
- tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
+ absl::string_view name() const override { return "bfloat16-fold"; }
// Run BF16 conversion folding on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 16e99b5722..32573ed355 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo) override;
- // Special handling for cross-replica-sum and sort which can have a tuple
- // output.
- Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleSort(HloInstruction* sort) override;
-
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
@@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
return Status::OK();
}
-Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
- HloInstruction* crs) {
- if (!ShapeUtil::IsTuple(crs->shape())) {
- return HandleInstruction(crs);
- } else {
- return HandleMultipleOutputs(crs);
- }
-}
-
-Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
- if (!ShapeUtil::IsTuple(sort->shape())) {
- return HandleInstruction(sort);
- } else {
- return HandleMultipleOutputs(sort);
- }
-}
-
Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
HloInstruction* hlo) {
std::vector<PrimitiveType> operand_types(hlo->operand_count());
@@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
+ if ((hlo->opcode() == HloOpcode::kSort ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
+ ShapeUtil::IsTuple(hlo->shape())) {
+ return HandleMultipleOutputs(hlo);
+ }
return HandleInstruction(hlo);
}
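The bfloat16_normalization.cc hunks fold the two tuple-aware visitor overrides into DefaultAction: tuple-shaped kSort and kCrossReplicaSum route to HandleMultipleOutputs, everything else to HandleInstruction. A toy model of that dispatch (the enum and struct are illustrative stand-ins, not XLA types):

    #include <iostream>

    enum class Opcode { kAdd, kSort, kCrossReplicaSum };
    struct Instr {
      Opcode opcode;
      bool tuple_shaped;
    };

    // Mirrors the consolidated DefaultAction check above.
    const char* Dispatch(const Instr& hlo) {
      if ((hlo.opcode == Opcode::kSort ||
           hlo.opcode == Opcode::kCrossReplicaSum) &&
          hlo.tuple_shaped) {
        return "HandleMultipleOutputs";
      }
      return "HandleInstruction";
    }

    int main() {
      std::cout << Dispatch({Opcode::kSort, true}) << "\n";   // HandleMultipleOutputs
      std::cout << Dispatch({Opcode::kAdd, false}) << "\n";   // HandleInstruction
    }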
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 2a60fe0af3..30b6346312 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16Normalization() override = default;
- tensorflow::StringPiece name() const override { return "bf16-normalization"; }
+ absl::string_view name() const override { return "bf16-normalization"; }
// Run BF16 normalization on the given computation. Returns whether the
// computation was changed.
@@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface {
~BFloat16MixedPrecisionRemoval() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "bf16-mixed-precision-removal";
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 49ae5320b0..b08705d4c2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase {
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
- HloVerifier verifier(/*allow_mixed_precision=*/true);
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true);
EXPECT_IS_OK(verifier.Run(module).status());
return result.ValueOrDie();
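This is the same signature change seen in the HloVerifiedTestBase subclasses earlier in the patch: boolean behavior flags become explicit constructor parameters, named inline with /*arg=*/ comments. A generic sketch of the idiom (Verifier here is a stand-in, not the real HloVerifier):

    #include <iostream>

    struct Verifier {
      Verifier(bool layout_sensitive, bool allow_mixed_precision)
          : layout_sensitive(layout_sensitive),
            allow_mixed_precision(allow_mixed_precision) {}
      bool layout_sensitive;
      bool allow_mixed_precision;
    };

    int main() {
      // The argument comments keep boolean call sites self-describing.
      Verifier verifier(/*layout_sensitive=*/false,
                        /*allow_mixed_precision=*/true);
      std::cout << verifier.allow_mixed_precision << "\n";  // 1
    }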
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 02b8cad089..1ee64971ab 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface {
~BFloat16Propagation() override = default;
- tensorflow::StringPiece name() const override {
- return "bfloat16-propagation";
- }
+ absl::string_view name() const override { return "bfloat16-propagation"; }
// Runs the pass on the given module. Returns whether the module was changed
// (precision reductions were added).
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 0f08e7c52b..c8c36ae60e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
@@ -36,20 +37,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
+namespace {
+using absl::StrAppend;
using ::tensorflow::gtl::FlatMap;
using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::HumanReadableNumBytes;
using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-
-namespace {
template <typename T>
string ColocatedBufferSetsToString(const T& container, const char* title) {
@@ -236,8 +234,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
}
string BufferAllocation::Slice::ToString() const {
- return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_,
- ", size:", size_, "}");
+ return absl::StrCat("{index:", index(), ", offset:", offset_,
+ ", size:", size_, "}");
}
BufferAllocation::Slice BufferAllocation::GetSlice(
@@ -678,9 +676,9 @@ string BufferAssignment::Stats::ToString() const {
string BufferAssignment::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "BufferAssignment:\n");
+ absl::StrAppend(&output, "BufferAssignment:\n");
for (auto& allocation : allocations_) {
- tensorflow::strings::StrAppend(&output, allocation.ToString());
+ absl::StrAppend(&output, allocation.ToString());
}
return output;
}
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 810d597e73..8d0ac3b84a 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -89,13 +89,13 @@ string BufferLiveness::ToString() const {
pieces.push_back(
tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b));
if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
return false;
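Note the TF_CHECK_OK to TF_DCHECK_OK switch above: the buffer verification now runs only in debug builds, dropping it from the hot path of every live-range query. The standard assert() has the same debug-only behavior:

    #include <cassert>
    #include <iostream>

    bool VerifyBuffer() { return true; }

    int main() {
      // Like TF_DCHECK_OK, this check disappears when NDEBUG is defined,
      // so release builds skip the verification entirely.
      assert(VerifyBuffer());
      std::cout << "ok\n";
    }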
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 3ffb7de65f..26e26e316d 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
- // Runs BufferLiveness on this computation.
- // Returns whether buffer interference is detected between tuple-shaped
- // parameter and root instructions at tuple element 1.
- bool Run(const bool update_uses_tuple_element1,
- const bool fuse_gte0 = false) {
+ std::unique_ptr<HloModule> BuildModule(const bool update_uses_tuple_element1,
+ const bool fuse_gte0) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
@@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
- auto tuple_root = builder.AddInstruction(
+ builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = CreateNewModule();
- module->AddEntryComputation(BuildDummyComputation());
- auto* computation = module->AddEmbeddedComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
+ auto* computation = module->entry_computation();
// Create fusion instruction based on number of tuple element 1 users.
if (update_uses_tuple_element1) {
computation->CreateFusionInstruction(
@@ -666,7 +664,14 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
computation->CreateFusionInstruction({gte0},
HloInstruction::FusionKind::kLoop);
}
+ return module;
+ }
+ // Returns whether buffer interference is detected between tuple-shaped
+ // parameter and root instructions at tuple element 1.
+ bool Run(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
// Run BufferLiveness on 'module'.
auto liveness = BufferLiveness::Run(
module.get(),
@@ -674,8 +679,24 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
.ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
+ bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
+    // Run HloDataflowAnalysis on 'module'.
+ auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
+ auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
+    // Return whether buffer interference is detected between
+ // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
+ return hlo_ordering->MayInterfere(
+ dataflow->GetUniqueValueAt(tuple_param0, {1}),
+ dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
+ }
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
+ EXPECT_FALSE(
+ RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
@@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
+ EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
+ /*fuse_gte0=*/true));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
+ EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc
index 2bc556a9e2..fdf822c666 100644
--- a/tensorflow/compiler/xla/service/buffer_value.cc
+++ b/tensorflow/compiler/xla/service/buffer_value.cc
@@ -17,11 +17,10 @@ limitations under the License.
#include <iosfwd>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index d6efef5f12..37523a73ff 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -18,20 +18,20 @@ limitations under the License.
#include <queue>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+using absl::StrCat;
using ::tensorflow::strings::Appendf;
-using ::tensorflow::strings::StrCat;
string CallContextToString(CallContext context) {
switch (context) {
@@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
}
string CallSite::ToString() const {
- return StrCat(instruction()->name(), " calls in context ",
- CallContextToString(context()), ": ",
- tensorflow::str_util::Join(
- called_computations(), ", ",
+ return StrCat(
+ instruction()->name(), " calls in context ",
+ CallContextToString(context()), ": ",
+ absl::StrJoin(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
out->append(computation->name());
}));
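CallSite::ToString above uses the three-argument absl::StrJoin overload, where a formatter lambda appends each element's rendering to *out instead of relying on the default formatter. A runnable sketch of that overload:

    #include <iostream>
    #include <string>
    #include <vector>

    #include "absl/strings/str_join.h"

    int main() {
      std::vector<int> ids = {7, 8, 9};
      // The lambda plays the role of the computation-name formatter above.
      std::string s = absl::StrJoin(ids, ", ", [](std::string* out, int id) {
        out->append("computation_" + std::to_string(id));
      });
      std::cout << s << "\n";  // computation_7, computation_8, computation_9
    }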
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c0e95e1578..c5cd88b9ea 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -35,7 +35,7 @@ class CallInliner : public HloPassInterface {
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
~CallInliner() override = default;
- tensorflow::StringPiece name() const override { return "CallInliner"; }
+ absl::string_view name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index e75f6f146d..5d85a3f173 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index 9c9e373821..601a3e9a01 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -16,13 +16,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 7426672a7a..3079695e96 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime(
if (!directory_path.empty()) {
HloSnapshot hlo_snapshot;
*hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation;
- string filename = tensorflow::strings::StrCat(
- "computation_", instance.computation.id(), "__",
- instance.computation.entry_computation_name());
+ string filename =
+ absl::StrCat("computation_", instance.computation.id(), "__",
+ instance.computation.entry_computation_name());
const string& per_host_path = tensorflow::io::JoinPath(
directory_path, tensorflow::port::Hostname());
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
index cb61f3da39..af8f7f1027 100644
--- a/tensorflow/compiler/xla/service/computation_layout.cc
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <algorithm>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -52,9 +52,8 @@ string ComputationLayout::ToString() const {
for (auto& param_layout : parameter_layouts_) {
params.push_back(param_layout.ToString());
}
- return tensorflow::strings::StrCat("(",
- tensorflow::str_util::Join(params, ", "),
- ") => ", result_layout_.ToString());
+ return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ",
+ result_layout_.ToString());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index afbbea35b8..61b1dba6c9 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
@@ -29,12 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
namespace xla {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index b7be3ba605..4ea3a13f28 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -28,8 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 063261e26d..3de50cbd7f 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -27,9 +27,7 @@ namespace xla {
// with their true or false computation as appropriate.
class ConditionalSimplifier : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "simplify-conditional";
- }
+ absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index c43a31b167..6c477da038 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
+ ConditionalSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index f213cc8709..498894737f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface {
public:
ConvolutionFeatureGroupConverter() {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-feature-group-converter";
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 3e39c1bab1..231d31d960 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -31,18 +33,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace {
+using absl::StrAppend;
+
bool IsEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
@@ -381,7 +378,7 @@ class CopyRemover {
}
string ToString() const {
- string out = StrCat("CopyRemover, module ", module_->name(), "\n");
+ string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n");
StrAppend(&out, " Buffer values, in dependency order:\n");
for (const HloBuffer& buffer : alias_analysis_.buffers()) {
StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
@@ -863,16 +860,16 @@ class CopyRemover {
for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
values.push_back(p->value);
}
- return StrCat("{",
- Join(values, ", ",
- [](string* s, const HloValue* value) {
- StrAppend(s, value->ToShortString());
- }),
- "}");
+ return absl::StrCat("{",
+ absl::StrJoin(values, ", ",
+ [](string* s, const HloValue* value) {
+ StrAppend(s, value->ToShortString());
+ }),
+ "}");
}
string ToString() const {
- string out = StrCat("BufferValueTracker:\n");
+ string out = absl::StrCat("BufferValueTracker:\n");
StrAppend(&out, " Def-use chains in each buffer:\n");
for (const ValueNode* head : value_lists_) {
StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
@@ -880,10 +877,10 @@ class CopyRemover {
const ValueNode* p = head;
do {
StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
- Join(p->uses, "; ",
- [](string* s, const HloUse* use) {
- StrAppend(s, use->ToString());
- }),
+ absl::StrJoin(p->uses, "; ",
+ [](string* s, const HloUse* use) {
+ StrAppend(s, use->ToString());
+ }),
"\n");
p = p->next;
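The absl::StrJoin calls above use the formatter overload, which appends each element to the output via a callback rather than requiring the elements to already be strings. A standalone sketch with toy values:

```c++
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  // Same shape as the rewrite above: join with a custom per-element
  // formatter, then wrap the result in braces.
  std::vector<int> ids = {1, 2, 3};
  std::string out = absl::StrCat(
      "{",
      absl::StrJoin(ids, ", ",
                    [](std::string* s, int id) {
                      absl::StrAppend(s, "value-", id);
                    }),
      "}");
  std::cout << out << "\n";  // {value-1, value-2, value-3}
}
```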
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 5ba64b78a3..f797ee7e4d 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -45,7 +45,7 @@ namespace xla {
// InstructionAliasSet::IsDistinct return true.
class CopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
// fusion_can_share_buffer: backend specific function that decides whether a
// fusion can share buffer with its operand.
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 850948b54b..e01fecffd0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -87,6 +87,8 @@ cc_library(
":parallel_task_assignment",
":simple_orc_jit",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ":target_machine_features",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
@@ -232,6 +234,7 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
"@llvm//:orc_jit",
],
)
@@ -279,6 +282,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:code_gen",
"@llvm//:core",
"@llvm//:support",
@@ -323,6 +327,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -365,6 +370,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -653,6 +659,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -816,6 +823,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -852,6 +860,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index e6fd1499ed..59437e88af 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface {
: target_machine_features_(*target_machine_features) {}
~ConvCanonicalization() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-canonicalization";
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 5116f926f5..279aa42fe2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -27,6 +27,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
@@ -101,8 +102,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace cpu {
@@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
-} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine) {
- LLVMTargetMachineFeatures target_machine_features(target_machine);
+} // namespace
- // Optimization pipeline.
- HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>();
+Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes through layout assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
- pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
+ pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
@@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
}
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
- [&target_machine_features](
- const HloInstruction& dot,
+ [&](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
+ return PotentiallyImplementedAsEigenDot(dot, *target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), &target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
+ return pipeline.Run(module).status();
+}
+
+Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes after layout assignment");
+ // After layout assignment, use a layout-sensitive verifier.
+ auto& after_layout_assn =
+ pipeline.AddPass<HloPassPipeline>("after layout assignment");
+ after_layout_assn.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
+
// The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicates or NOPs, so remove them with algebraic simplification and CSE.
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
- "after layout assignement");
+ "simplification after layout assignement");
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
@@ -322,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
+
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
@@ -335,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
pipeline.AddPass<ParallelTaskAssigner>(
- max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
+ max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
}
- // Copy insertion should be performed immediately before IR emission to avoid
- // inserting unnecessary copies (later pass adds an instruction which
- // materializes the value) or missing a necessary copy (later pass removes an
- // instruction which materializes a value). DCE must be run immediately before
- // (and sometime after) copy insertion, to avoid dead code from interfering
- // with the rewrites.
+  // Copy insertion should be performed immediately before IR emission to
+  // avoid inserting unnecessary copies (a later pass adds an instruction
+  // which materializes the value) or missing a necessary copy (a later pass
+  // removes an instruction which materializes a value). DCE must be run
+  // immediately before (and sometimes after) copy insertion, to keep dead
+  // code from interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
@@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
return pipeline.Run(module).status();
}
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+ TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
+ &target_machine_features));
+ return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
+ &target_machine_features);
+}
+
namespace {
// Align buffers to 16-byte boundaries.
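The refactor above splits the monolithic RunHloPasses into a phase that runs through layout assignment and a phase that runs after it, with the new wrapper preserving the old end-to-end behavior. A self-contained analogue of the design choice (illustrative names, not the XLA API):

```c++
#include <iostream>
#include <string>
#include <vector>

// Two pipelines instead of one: the seam at layout assignment lets a caller
// act on the module between the phases, and lets each phase install a
// verifier with the right layout sensitivity.
struct Pipeline {
  std::string name;
  std::vector<std::string> passes;
  void Run() const {
    for (const auto& p : passes) std::cout << name << ": " << p << "\n";
  }
};

int main() {
  Pipeline through_layout{"through-layout",
                          {"verifier(layout_insensitive)", "simplification",
                           "cpu-layout-assignment"}};
  Pipeline after_layout{"after-layout",
                        {"verifier(layout_sensitive)", "simplification",
                         "dce", "copy-insertion"}};
  through_layout.Run();
  // The layout boundary: module layouts are now fixed.
  after_layout.Run();
}
```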
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 04e1c48872..47b5edabff 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler {
Status RunHloPasses(HloModule* module, bool is_aot_compile,
llvm::TargetMachine* target_machine);
+ // Runs HLO passes up to and including layout assignment.
+ Status RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features);
+
+ // Runs HLO passes after layout assignment.
+ Status RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features);
+
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index 6398d8c98d..d49f7d7cc2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -32,7 +32,7 @@ namespace xla {
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index c376864c3e..fbcbbbd200 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -35,8 +37,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -177,12 +177,12 @@ Status CpuExecutable::ExecuteComputeFunction(
buffer_pointers.size(), profile_counters_size);
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
- tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
+ absl::StrAppend(out, tensorflow::strings::Printf("%p", p));
};
VLOG(3) << " params = nullptr";
VLOG(3) << tensorflow::strings::Printf(
" temps = [%s]",
- tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
+ absl::StrJoin(buffer_pointers, ", ", ptr_printer).c_str());
VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
profile_counters);
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 2924b63659..6af724b2a5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface {
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "cpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "cpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index e6130c7d76..c3e03056f0 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <set>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
@@ -773,8 +774,8 @@ class GatherLoopFusionTest
TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
const GatherLoopFusionTestSpec& spec = GetParam();
- string hlo_string = tensorflow::strings::StrCat(
- "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
+ string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n",
+ spec.hlo_computation_text);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_string));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index 69acca86bf..bfecbd6e01 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -34,8 +34,8 @@ namespace cpu {
// instruction stream.
namespace {
-using ::absl::nullopt;
-using ::absl::optional;
+using absl::nullopt;
+using absl::optional;
using ShouldMakeOperandColMajorCache =
tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index b6039b465e..b8ace57026 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace {
@@ -51,7 +52,7 @@ absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) {
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
int64 tiling_factor;
if (it != extra_options_map.end() &&
- tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
+ absl::SimpleAtoi(it->second, &tiling_factor)) {
return tiling_factor;
}
return absl::nullopt;
@@ -63,8 +64,8 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
}
-static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
- tensorflow::StringPiece suffix) {
+static absl::string_view RemoveSuffix(absl::string_view str,
+ absl::string_view suffix) {
CHECK_GE(str.size(), suffix.size());
CHECK_EQ(str.substr(str.size() - suffix.size()), suffix);
return str.substr(0, str.size() - suffix.size());
@@ -79,22 +80,21 @@ absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
return absl::nullopt;
}
- std::vector<string> tile_components =
- tensorflow::str_util::Split(it->second, ':');
+ std::vector<string> tile_components = absl::StrSplit(it->second, ':');
CHECK_EQ(tile_components.size(), 3);
int64 tile_size_m;
int64 tile_size_k;
int64 tile_size_n_in_vector_width;
- CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m));
- CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k));
+ CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m));
+ CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k));
- tensorflow::StringPiece tile_size_n_in_vector_width_str =
+ absl::string_view tile_size_n_in_vector_width_str =
RemoveSuffix(tile_components[2], "*vectwidth");
- CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str,
- &tile_size_n_in_vector_width));
+ CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str,
+ &tile_size_n_in_vector_width));
return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k,
tile_size_n_in_vector_width);
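The rewritten LlvmIrGemmTileSize parses an extra-options flag of the form "m:k:n*vectwidth". A standalone sketch of the same parsing steps, on a toy input:

```c++
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

int main() {
  // Same steps as above: split on ':', convert the first two components,
  // strip the "*vectwidth" suffix from the third before converting it.
  std::vector<std::string> parts = absl::StrSplit("2:8:4*vectwidth", ':');
  int64_t m, k, n;
  bool ok = absl::SimpleAtoi(parts[0], &m) && absl::SimpleAtoi(parts[1], &k);
  absl::string_view third(parts[2]);
  third.remove_suffix(sizeof("*vectwidth") - 1);  // drop the 10-char suffix
  ok = ok && absl::SimpleAtoi(third, &n);
  if (!ok) return 1;
  std::cout << m << " " << k << " " << n << "\n";  // 2 8 4
}
```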
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 797392c265..4af16f4fa0 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -146,9 +147,9 @@ class GemvConfig {
bool has_addend() const { return has_addend_; }
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
- tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+ return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
+ tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
+ has_addend() ? "_with_addend" : "");
}
protected:
@@ -642,9 +643,7 @@ class TiledSmallGemmEmitter {
int64 k() const { return k_; }
int64 n() const { return n_; }
- string ToString() const {
- return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
- }
+ string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
private:
const int64 m_;
@@ -687,10 +686,10 @@ class TiledSmallGemmEmitter {
tile_size_k_(tile_size_k) {}
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- "gemm_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
- "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
- tile_size_m(), "_", tile_size_k());
+ return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
+ dims().ToString(), "_", max_vectorization_width(),
+ "_", min_vectorization_width(), "_", tile_size_m(),
+ "_", tile_size_k());
}
PrimitiveType scalar_type() const { return scalar_type_; }
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 05322faa75..4c2041b556 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
+#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6f433b4f30..417a1dba1f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
@@ -67,7 +68,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -502,7 +502,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
@@ -846,7 +846,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
loops
.AddLoop(
0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
+ absl::StrCat("k", i))
->GetIndVarValue();
}
llvm::Value* input_feature =
@@ -2118,7 +2118,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
- tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
+ absl::string_view custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
@@ -2687,9 +2687,8 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
auto buf_it = thread_local_buffers_.find(key);
if (buf_it == thread_local_buffers_.end()) {
llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
- IrShapeType(shape),
- tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
- MinimumAlignmentForShape(target_shape));
+ IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
+ &b_, MinimumAlignmentForShape(target_shape));
auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
CHECK(it_inserted_pair.second);
buf_it = it_inserted_pair.first;
@@ -2753,7 +2752,7 @@ Status IrEmitter::EmitTargetElementLoop(
}
Status IrEmitter::EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator) {
VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
@@ -2848,7 +2847,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
llvm::Value* IrEmitter::EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
@@ -2869,7 +2868,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
- tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
b_.CreateCall(
@@ -2886,7 +2885,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
b_.CreateCall(FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
/*parameter_addresses=*/{}, &b_, name,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index c9a1dab62d..99c080b3db 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
@@ -107,7 +107,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name);
+ absl::string_view name);
protected:
//
@@ -239,7 +239,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// function that a map operation applies.
StatusOr<llvm::Function*> EmitFunction(
HloComputation* function, // The function to emit.
- tensorflow::StringPiece
+ absl::string_view
function_name_suffix); // Used for LLVM IR register names.
// Emits a call to a thread local function (e.g. to the computation nested
@@ -251,14 +251,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* EmitThreadLocalCall(
const HloComputation& callee,
tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name);
+ absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within
  // a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
// the parameters and return values for these computations so there is no need
// to explicitly pass parameters or return results.
- void EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name);
+ void EmitGlobalCall(const HloComputation& callee, absl::string_view name);
// Returns the buffer to which a global call to `callee` would have written
// its result.
@@ -285,7 +284,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* target_op,
const llvm_ir::ElementGenerator& element_generator);
Status EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator);
// Emits a memcpy from the source instruction's result value to the
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 2db4d000f5..784045313d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -189,7 +190,7 @@ void IrFunction::Initialize(const string& function_name,
llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
CHECK_GT(num_dynamic_loop_bounds_, 0);
CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
- string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
+ string name = absl::StrCat("dynamic_loop_bound_", offset);
return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
b_->getInt64(offset), AsStringRef(name)));
}
@@ -200,7 +201,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* b, absl::string_view name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
@@ -211,13 +212,13 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
} else {
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+ absl::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr =
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
+ AsStringRef(absl::StrCat(name, "_parameter_", i,
+ "_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
@@ -320,8 +321,7 @@ Status EmitCallToParallelForkJoin(
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/partitions_array,
/*Name=*/
- AsStringRef(
- tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
+ AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions")));
// Add argument specifying parallel dimension partitions.
fork_join_arguments.push_back(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index a41cbb64cd..ee7595f6e9 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -116,7 +116,7 @@ class IrFunction {
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* b, absl::string_view name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 8560e4296a..aedb069dce 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
dynamic_loop_bounds_(dynamic_loop_bounds) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
CHECK(!ShapeUtil::IsTuple(shape_));
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 076c683ca5..a604e1db22 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 286d407ca6..b4c0c09ec0 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
@@ -217,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
// Outline 'instruction' in 'computation' for parallel task assignment.
auto* call = module->OutlineExpressionFromComputation(
- {instruction},
- tensorflow::strings::StrCat("parallel_", instruction->name()),
+ {instruction}, absl::StrCat("parallel_", instruction->name()),
computation);
// Set assigned dimension partitioning to 'instruction'.
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 8becc8fa23..a99cd99c14 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface {
target_machine_features_(*target_machine_features) {}
~ParallelTaskAssigner() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cpu-parallel-task-assigner";
}
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index ee272b5f4f..a84ee78b19 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index b026aef3fe..bf98064647 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -170,15 +170,14 @@ namespace {
bool RegisterKnownJITSymbols() {
CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
-#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
- do { \
- auto* function_address = \
- reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
- registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
- function_address); \
- CHECK_EQ( \
- tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \
- "__xla_cpu_runtime_" #base_name); \
+#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
+ do { \
+ auto* function_address = \
+ reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
+ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
+ function_address); \
+ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
+ "__xla_cpu_runtime_" #base_name); \
} while (false)
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
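The reformatted macro pairs token pasting (##) with stringization (#) so the CHECK verifies that each registered symbol-name constant matches the literal function name. A toy reproduction of the pattern, with all names invented:

```c++
#include <cassert>
#include <cstdio>
#include <string>

// Toy analogue of REGISTER_CPU_RUNTIME_SYMBOL: ##-paste the base name onto
// the function and constant identifiers, #-stringize it for the check.
void toy_runtime_Foo() { std::puts("Foo called"); }
const char* kFooSymbolName = "toy_runtime_Foo";

#define REGISTER_TOY_SYMBOL(base_name)                              \
  do {                                                              \
    auto* function_address =                                        \
        reinterpret_cast<void*>(toy_runtime_##base_name);           \
    assert(std::string(k##base_name##SymbolName) ==                 \
           "toy_runtime_" #base_name);                              \
    (void)function_address; /* a real registry would store this */  \
  } while (false)

int main() { REGISTER_TOY_SYMBOL(Foo); }
```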
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 4635fa5d74..2384166fd2 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -110,6 +110,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -123,6 +124,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index 6fcce42eaa..fcd87b36b3 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
index 973aac8766..9457e57d7b 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android";
struct IntrinsicTestSpec {
HloOpcode opcode;
- tensorflow::StringPiece triple;
- tensorflow::StringPiece features;
- tensorflow::StringPiece check_lines;
+ absl::string_view triple;
+ absl::string_view features;
+ absl::string_view check_lines;
};
// Tests that unary functions get lowered using intrinsic calls.
@@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest
features = "";
}
- return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(),
- features.empty() ? "" : "_With",
- features.c_str());
+ return absl::StrCat(opcode.c_str(), "_On_", triple.c_str(),
+ features.empty() ? "" : "_With", features.c_str());
}
};
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index 56b28fd22d..c326beb899 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -29,7 +29,7 @@ class Defuser : public HloPassInterface {
public:
Defuser() {}
~Defuser() override {}
- tensorflow::StringPiece name() const override { return "defuser"; }
+ absl::string_view name() const override { return "defuser"; }
// Run defusion on the given module. Returns whether the module was
// changed.
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index e727ba49cb..37d1895d41 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
+ public:
+ DefuserTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index 48e4471499..ba2a674d9a 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -27,9 +27,7 @@ namespace {
class ControlDepRemover : public HloPassInterface {
public:
ControlDepRemover() = default;
- tensorflow::StringPiece name() const override {
- return "control-dep-remover";
- }
+ absl::string_view name() const override { return "control-dep-remover"; }
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index cc1695b7f8..7be70add2f 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -33,7 +33,7 @@ namespace xla {
class Despecializer : public HloPassInterface {
public:
Despecializer();
- tensorflow::StringPiece name() const override { return "despecializer"; }
+ absl::string_view name() const override { return "despecializer"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 690b5df514..275e6cc61d 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 20c6bafe7c..6ec4893f7a 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,13 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index 1959b687f1..fc38e31700 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface {
DotDecomposer(bool decompose_batch_dot = true)
: decompose_batch_dot_(decompose_batch_dot) {}
~DotDecomposer() = default;
- tensorflow::StringPiece name() const override { return "dot_decomposer"; }
+ absl::string_view name() const override { return "dot_decomposer"; }
// Run DotDecomposer pass on computations in 'module'.
// Returns whether the 'module' was changed.
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index cc7a87f9e8..26af67cc1c 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -22,6 +22,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
@@ -39,17 +40,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+using absl::StrCat;
using llvm_ir::AsStringRef;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrCat;
namespace {
@@ -306,18 +306,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
{operand_value->getType()}, b_);
}
case HloOpcode::kSign: {
- bool is_signed =
- primitive_util::IsSignedIntegralType(op->shape().element_type());
+ CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
+ << op->shape().element_type();
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto cmp = b_->CreateICmpEQ(operand_value, GetZero(type));
- if (is_signed) {
- auto ashr =
- b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
- return Select(cmp, GetZero(type), b_->CreateOr(ashr, 1));
- } else {
- return Select(cmp, GetZero(type), GetOne(type));
- }
+ auto ashr = b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
+ return Select(cmp, GetZero(type), b_->CreateOr(ashr, 1));
}
case HloOpcode::kNegate:
return b_->CreateNeg(operand_value);
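The kSign rewrite drops the unsigned branch (the new CHECK asserts the element type is signed) and keeps only the branch-free signed formula: arithmetic shift right by bit-width minus one yields 0 for non-negative and -1 for negative values, OR-ing in 1 maps those to +1 and -1, and the compare against zero selects 0. A worked standalone version:

```c++
#include <cstdint>
#include <iostream>

// sign(x) for signed x, mirroring the IR above: ashr(x, bits-1) is 0 when
// x >= 0 and -1 when x < 0; OR 1 gives +1 or -1; the x == 0 test selects 0.
// Note: right-shifting a negative signed value is implementation-defined
// before C++20, though arithmetic on mainstream compilers; CreateAShr
// guarantees arithmetic shift at the IR level.
int64_t Sign(int64_t x) {
  int64_t ashr = x >> 63;
  return x == 0 ? 0 : (ashr | 1);
}

int main() {
  std::cout << Sign(-42) << " " << Sign(0) << " " << Sign(7) << "\n";  // -1 0 1
}
```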
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index d3efab3614..3cccec9862 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -28,7 +28,7 @@ namespace xla {
// points-to analysis (see b/36865746 for details).
class FlattenCallGraph : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "flatten-call-graph"; }
+ absl::string_view name() const override { return "flatten-call-graph"; }
// Duplicates computations called from multiple call- or while-nodes to
// flatten the call graph.
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index c1fc8574da..7bd9ea5984 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -25,7 +25,7 @@ namespace xla {
// nevertheless have a minimum level of support.
class GatherExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "gather_expander"; }
+ absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index fbef487ac8..e53f525517 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -129,6 +129,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -186,6 +187,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm//:core",
"@llvm//:support",
@@ -231,6 +233,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
],
@@ -347,6 +350,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -384,6 +388,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@@ -402,6 +407,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -496,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -527,6 +534,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -687,6 +695,7 @@ cc_library(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm//:core",
],
@@ -775,6 +784,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/strings",
],
)
@@ -888,9 +898,8 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_runner",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 6a285a6b98..f22c2a8add 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include <cmath>
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -74,9 +74,8 @@ ENTRY MaxDifference {
%error = f32[SIZE] divide(%sub_abs, %denominator)
ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
})";
- auto size_string = std::to_string(num_elements);
- return tensorflow::str_util::StringReplace(
- kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+ return absl::StrReplaceAll(kF16CompHloText,
+ {{"SIZE", absl::StrCat(num_elements)}});
}
StatusOr<F16BufferComparator> F16BufferComparator::Create(
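The buffer-comparator rewrite swaps tensorflow::str_util::StringReplace for absl::StrReplaceAll, which takes a list of {from, to} pairs and replaces every occurrence. A standalone sketch of the substitution, with a toy element count:

```c++
#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"

int main() {
  // Splice a concrete element count into an HLO template, as above.
  std::string hlo = absl::StrReplaceAll(
      "%error = f32[SIZE] divide(%sub_abs, %denominator)",
      {{"SIZE", absl::StrCat(1024)}});
  std::cout << hlo << "\n";  // %error = f32[1024] divide(...)
}
```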
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 7833a4077e..854a2f50b2 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index e09cde9abf..6e2e330edd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -54,9 +54,7 @@ namespace gpu {
// BatchNormRewriter.
class CudnnBatchNormRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "cudnn_batchnorm_rewriter";
- }
+ absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 7b172812c3..18a76e8c26 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5a8fc76e85..3d421ebb69 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -21,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -128,14 +128,14 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
string AlgorithmToString(const AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
- return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+ return absl::StrCat(algo.algo_id(), "+TC");
}
- return tensorflow::strings::StrCat(algo.algo_id());
+ return absl::StrCat(algo.algo_id());
}
string NumBytesToString(int64 bytes) {
- return tensorflow::strings::StrCat(
- tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
+ return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
+ bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
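For context on the StrCat calls above: absl::StrCat accepts a mix of integral and string-like arguments and allocates the result once. A standalone sketch, where AlgoLabel is a hypothetical stand-in for AlgorithmToString:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"

// Hypothetical stand-in for AlgorithmToString; demonstrates StrCat's
// heterogeneous arguments (an int64_t and a string literal).
std::string AlgoLabel(int64_t algo_id, bool tensor_ops_enabled) {
  if (tensor_ops_enabled) {
    return absl::StrCat(algo_id, "+TC");
  }
  return absl::StrCat(algo_id);
}

int main() {
  std::cout << AlgoLabel(7, true) << "\n";   // prints: 7+TC
  std::cout << AlgoLabel(2, false) << "\n";  // prints: 2
  return 0;
}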
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 472de2ff0f..f76d273e8c 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -39,7 +39,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
Compiler* compiler)
: stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-algorithm-picker";
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index 0c0578d888..fbe7e98494 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -26,7 +26,7 @@ namespace gpu {
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
class CudnnConvolutionRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-rewriter";
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 7b0d9e53d6..68086c86e9 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator {
"Can't allocate twice from a ScratchBufAllocator.");
}
if (byte_size > scratch_.size()) {
- return se::port::InternalError(tensorflow::strings::StrCat(
+ return se::port::InternalError(absl::StrCat(
"Can't allocate ", byte_size,
" bytes from a ScratchBufAllocator of size ", scratch_.size()));
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 9b6de115ad..2460d951bd 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
@@ -43,16 +45,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gpu {
+using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrAppend;
namespace {
// Returns whether operand is a floating-point literal with the given value.
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index 0cdddf8bcf..def595d217 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 9b86e5315b..1bd88233e1 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -19,12 +19,12 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -289,11 +289,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
<< " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion)
<< " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio
<< " into users { "
- << tensorflow::str_util::Join(users, ", ",
- [](string* out, HloInstruction* user) {
- tensorflow::strings::StrAppend(
- out, user->name());
- })
+ << absl::StrJoin(users, ", ",
+ [](string* out, HloInstruction* user) {
+ absl::StrAppend(out, user->name());
+ })
<< " }";
// Remove 'fusion' instruction.
CHECK_EQ(0, fusion->user_count());
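The StrJoin call above uses the formatter overload: the lambda appends each element's rendering directly into the shared output buffer instead of materializing a temporary string per element. A self-contained sketch, with a toy Instr struct standing in for HloInstruction:

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

// Toy stand-in for HloInstruction, assumed here for illustration only.
struct Instr {
  std::string name;
};

int main() {
  Instr a{"multiply.1"}, b{"reduce.2"};
  std::vector<Instr*> users = {&a, &b};
  // The lambda is StrJoin's formatter: it appends each element to *out.
  const std::string joined =
      absl::StrJoin(users, ", ", [](std::string* out, const Instr* user) {
        absl::StrAppend(out, user->name);
      });
  std::cout << "into users { " << joined << " }\n";
  return 0;
}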
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 4c523a66de..7e3f5775b8 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -34,7 +34,7 @@ namespace gpu {
//
class FusionMerger : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "fusion merger"; }
+ absl::string_view name() const override { return "fusion merger"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 74282c568c..2c02ec2584 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 0c6f9b511f..8ffae18fe8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -27,7 +27,7 @@ namespace gpu {
// inserting kCopy instructions.
class GpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 09a1d9c12b..627a05e240 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index d63e213d2b..bbb3340760 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface {
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "gpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "gpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 286547ebae..fbc8ddf599 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
// Enumerate all combinations of shapes.
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
for (int constrained_param_no : {0, 4}) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 8c11cd0541..0e205b9c02 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -24,16 +25,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index c349063c71..f544bcc919 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -215,7 +215,7 @@ bool IsReductionToVector(const HloInstruction& reduce) {
// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+llvm::Value* EmitPrintf(absl::string_view fmt,
tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 5d23a3d018..a35e250101 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -126,7 +126,7 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo);
bool IsReductionToVector(const HloInstruction& reduce);
// Emits call to "vprintf" with given format and arguments.
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+llvm::Value* EmitPrintf(absl::string_view fmt,
tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
llvm::IRBuilder<>* builder);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 561c683879..76e069fc41 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bda2986202..84043689bd 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
@@ -90,10 +91,10 @@ namespace {
using absl::InlinedVector;
using absl::nullopt;
using absl::optional;
+using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
-using tensorflow::strings::StrCat;
// If a dimensions is smaller than this, untiled transposition may be more
// efficient.
@@ -801,8 +802,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
// //
// // and threads_per_block is a multiple of warpSize.
- // reduce_kernel<<<num_blocks, threads_per_block>>>();
- //
+ // reduce_kernel<<<num_blocks, threads_per_block>>>();

auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index 6305396635..d856299889 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -16,11 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -41,8 +41,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
tensorflow::mutex_lock lock(mutex_);
if (!loader_spec_) {
loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
- tensorflow::StringPiece ptx = executable.ptx();
- // Convert tensorflow::StringPiece to se::port::StringPiece because
+ absl::string_view ptx = executable.ptx();
+ // Convert absl::string_view to se::port::StringPiece because
// StreamExecutor uses the latter.
loader_spec_->AddCudaPtxInMemory(
se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
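As the comment above notes, absl::string_view and StreamExecutor's se::port::StringPiece were distinct types with no implicit conversion at the time of this change, so the code bridges them through the data()/size() pair. A sketch under that assumption, with a toy view type standing in for se::port::StringPiece:

#include <cstddef>
#include <iostream>
#include <string>

#include "absl/strings/string_view.h"

// Toy stand-in for se::port::StringPiece, assumed for illustration;
// the real type lives in StreamExecutor.
class ToyStringPiece {
 public:
  ToyStringPiece(const char* data, size_t size) : data_(data), size_(size) {}
  std::string ToString() const { return std::string(data_, size_); }

 private:
  const char* data_;
  size_t size_;
};

int main() {
  absl::string_view ptx = ".version 6.0\n.target sm_70\n";
  // No implicit conversion exists between the two view types, so the
  // pointer/length pair is passed explicitly, as KernelThunk::Initialize does.
  ToyStringPiece piece(ptx.data(), ptx.size());
  std::cout << piece.ToString();
  return 0;
}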
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index 6bd9c58f83..ccf082c4c6 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -35,6 +35,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@llvm//:amdgpu_code_gen",
"@llvm//:analysis",
"@llvm//:bit_reader",
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
index 12a8a59488..a3c74507dd 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -86,7 +86,7 @@ void IrDumpingPassManager::run(llvm::Module &module) {
const llvm::PassInfo *PI =
llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID());
const string basename = ReplaceFilenameExtension(
- tensorflow::io::Basename(input_filename_),
+ absl::string_view(tensorflow::io::Basename(input_filename_)),
tensorflow::strings::Printf(
"pass-%02d.before.%s.ll", i,
(PI == nullptr ? "unknown" : PI->getPassArgument().data())));
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index cce6e48141..e18d7e764a 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
@@ -54,9 +56,7 @@ limitations under the License.
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Scalar.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -107,8 +107,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
<< ", " << compute_capability.second << ") ."
<< "Defaulting to libdevice for compute_" << libdevice_version;
}
- return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version,
- ".10.bc");
+ return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc");
}
// Gets the GPU name as it's known to LLVM for a given compute capability. If
@@ -138,15 +137,16 @@ static string GetSmName(std::pair<int, int> compute_capability) {
<< "Defaulting to telling LLVM that we're compiling for sm_"
<< sm_version;
}
- return tensorflow::strings::StrCat("sm_", sm_version);
+ return absl::StrCat("sm_", sm_version);
}
// Convenience function for producing a name of a temporary compilation product
// from the input filename.
string MakeNameForTempProduct(const std::string& input_filename,
- tensorflow::StringPiece extension) {
- return ReplaceFilenameExtension(
- tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension);
+ absl::string_view extension) {
+ return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename(
+ llvm_ir::AsString(input_filename))),
+ extension);
}
// Initializes LLVM passes. Uses the PassRegistry mechanism.
@@ -167,7 +167,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) {
// Returns the TargetMachine, given a triple.
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
- llvm::Triple triple, tensorflow::StringPiece cpu_name,
+ llvm::Triple triple, absl::string_view cpu_name,
const HloModuleConfig& hlo_module_config) {
std::string error;
const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
@@ -243,9 +243,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
}
// Emits the given module to a bit code file.
-void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) {
+void EmitBitcodeToFile(const Module& module, absl::string_view filename) {
std::error_code error_code;
- llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code,
+ llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
llvm::sys::fs::F_None);
if (error_code) {
LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
@@ -266,8 +266,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
// get creative to add a suffix.
string module_id(llvm_ir::AsString(module->getModuleIdentifier()));
IrDumpingPassManager codegen_passes(
- ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
- "-nvptx.dummy"),
+ ReplaceFilenameExtension(
+ absl::string_view(tensorflow::io::Basename(module_id)),
+ "-nvptx.dummy"),
"", false);
codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
llvm::Triple(module->getTargetTriple())));
@@ -332,8 +333,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {
- return tensorflow::errors::Internal(tensorflow::strings::StrCat(
- "Error linking libdevice from ", libdevice_path));
+ return tensorflow::errors::Internal(
+ absl::StrCat("Error linking libdevice from ", libdevice_path));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
index 54e0e140de..9654175bfa 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
namespace gpu {
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
index 9ef9bc3a50..3b2c3591d9 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
@@ -17,13 +17,13 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace {
@@ -52,14 +52,13 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
return module;
}
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension) {
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension) {
auto pos = filename.rfind('.');
- tensorflow::StringPiece stem =
- pos == tensorflow::StringPiece::npos
- ? filename
- : tensorflow::StringPiece(filename.data(), pos);
- return tensorflow::strings::StrCat(stem, ".", new_extension);
+ absl::string_view stem = pos == absl::string_view::npos
+ ? filename
+ : absl::string_view(filename.data(), pos);
+ return absl::StrCat(stem, ".", new_extension);
}
} // namespace gpu
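A usage sketch of the rewritten helper above (a self-contained copy for illustration; the real one lives in llvm_gpu_backend/utils.cc):

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

// Standalone copy of the rewritten helper, for illustration only.
std::string ReplaceFilenameExtension(absl::string_view filename,
                                     absl::string_view new_extension) {
  auto pos = filename.rfind('.');
  absl::string_view stem = pos == absl::string_view::npos
                               ? filename
                               : absl::string_view(filename.data(), pos);
  // StrCat accepts string_view directly, so no ToString() copy is needed.
  return absl::StrCat(stem, ".", new_extension);
}

int main() {
  std::cout << ReplaceFilenameExtension("/foo/baz.txt", "cc") << "\n";
  std::cout << ReplaceFilenameExtension("noext", "ll") << "\n";  // noext.ll
  return 0;
}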
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
index a6daeca95a..60f4926849 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace llvm {
class LLVMContext;
@@ -41,8 +41,8 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
//
// For example:
// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc"
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension);
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 5575f6c0c6..9fb6f569ae 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -49,7 +49,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// If possible, we want to pick a reduce operand of the fusion root,
// because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
- if (inst->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*inst)) {
return inst;
}
}
@@ -64,7 +64,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
auto get_element_shape = [&](const HloInstruction* element_instr) {
// Special handling of kReduce instructions -- the fusion
// applies to the first operand.
- if (element_instr->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*element_instr)) {
return element_instr->operand(0)->shape();
}
return element_instr->shape();
@@ -141,10 +141,15 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
- // We can fuse reduces and loop fusions.
- return IsInputFusibleReduction(instr) ||
- (instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
+ // We can fuse reduces and loop fusions. Elementwise instructions can be fused
+ // with any other instruction.
+ // TODO(b/112957171): This should use the same isFusible logic as
+ // instruction_fusion.
+ return instr->IsFusable() &&
+ (IsInputFusibleReduction(instr) ||
+ (instr->opcode() == HloOpcode::kFusion &&
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop) ||
+ instr->IsElementwise());
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
@@ -178,28 +183,16 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
// merge into bigger loop fusions and input (reduce) fusions become fusions
// with multiple reduce outputs. We could fuse reduce and loop fusions
// together too (the result being an input fusion) if we find cases where this
- // improves things.
+ // improves things. Also disable fusing standalone input-fusible reduces into
+ // loop fusions.
CHECK(instr1->opcode() == HloOpcode::kFusion);
if ((instr2->opcode() == HloOpcode::kFusion &&
instr1->fusion_kind() != instr2->fusion_kind()) ||
- (instr2->opcode() != HloOpcode::kFusion &&
+ (IsReductionToVector(*instr2) &&
instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) {
return false;
}
- // Multi-output loop fusions must have equal output shapes to be lowered.
- if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) {
- Shape shape1 = instr1->IsMultiOutputFusion()
- ? instr1->shape().tuple_shapes(0)
- : instr1->shape();
- Shape shape2 = instr2->IsMultiOutputFusion()
- ? instr2->shape().tuple_shapes(0)
- : instr2->shape();
- if (!ShapeUtil::Equal(shape1, shape2)) {
- return false;
- }
- }
-
// Do this check last, as it may be expensive.
return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2);
}
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 072f885bc1..c822c94f1b 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -15,19 +15,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
+namespace op = xla::testing::opcode_matchers;
+
using MultiOutputFusionTest = HloTestBase;
const char kModulePrefix[] = R"(
@@ -47,7 +47,7 @@ const char kModulePrefix[] = R"(
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
// Fusion with reduce instruction root and a sibling reduce instruction
// sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
mul = f32[6400]{0} multiply(p1.1, p1.1)
@@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
// Two sibling fusions with reduce instruction roots sharing the same input
// param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
// Multi-output fusion with two reduce instructions root and a sibling reduce
// instruction sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
const.1 = f32[] constant(1)
p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
@@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
// Verify that if we already have a multi-output fusion that we prefer to pick
// a reduce op from its operands for checking shape compatibility.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest,
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@@ -256,6 +256,50 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
+ // Fusing a reduce into a loop fusion would require changing the fusion kind.
+ // That's not supported yet.
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(0)
+ reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation
+ ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(1)
+ div = f32[6400]{0} divide(p0, const.2)
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Divide()));
+}
+
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_computation_1 {
@@ -341,7 +385,7 @@ TEST_F(MultiOutputFusionTest,
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -361,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -388,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
c0 = f32[] constant(0)
@@ -429,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -456,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
c0 = f16[] constant(0)
@@ -497,7 +541,7 @@ TEST_F(MultiOutputFusionTest,
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
mixed_input_layouts_computation {
p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 5868c1a42e..695feadb11 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
@@ -85,7 +87,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -140,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -156,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -203,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
@@ -218,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
}
{
- HloPassPipeline pipeline("layout_assignment");
+ // Run layout assignment in a separate pipeline from
+ // "post-layout-assignment" because we want everything after layout
+ // assignment to have a layout-sensitive invariant-checker, but
+ // HloPassPipeline also runs its invariant checker before any passes are
+ // run, meaning the pipeline that contains layout assignment cannot contain
+ // a layout-sensitive verifier!
+ HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), stream_exec);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("post-layout_assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -266,17 +283,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>();
+ fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
fusion.AddPass<GpuMultiOutputFusion>();
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
+ fusion.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>();
+ reduce_pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true, /*allow_mixed_precision=*/false);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -302,7 +322,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -352,9 +373,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
string vmaj_str, vmin_str, vdot_str;
if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str,
&vmin_str, &vdot_str) ||
- !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) ||
- !tensorflow::strings::safe_strto64(vmin_str, &vmin) ||
- !tensorflow::strings::safe_strto64(vdot_str, &vdot)) {
+ !absl::SimpleAtoi(vmaj_str, &vmaj) ||
+ !absl::SimpleAtoi(vmin_str, &vmin) ||
+ !absl::SimpleAtoi(vdot_str, &vdot)) {
LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path
<< " --version:\n"
<< out;
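absl::SimpleAtoi mirrors the old safe_strto64 contract: it returns false on non-numeric or overflowing input rather than truncating. A minimal sketch with hard-coded version fields standing in for the RE2 capture groups:

#include <cstdint>
#include <iostream>

#include "absl/strings/numbers.h"

int main() {
  int64_t vmaj = 0, vmin = 0;
  // Stand-ins for the RE2 captures in WarnIfBadPtxasVersion.
  const bool ok =
      absl::SimpleAtoi("9", &vmaj) && absl::SimpleAtoi("2", &vmin);
  if (!ok) {
    std::cerr << "Couldn't parse ptxas version\n";
    return 1;
  }
  std::cout << "ptxas " << vmaj << "." << vmin << "\n";  // ptxas 9.2
  return 0;
}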
@@ -466,7 +487,7 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
tensorflow::SubProcess ptxas_info_dumper;
std::vector<string> ptxas_args = {
ptxas_path, ptx_path, "-o", cubin_path,
- tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)};
+ absl::StrCat("-arch=sm_", cc_major, cc_minor)};
if (VLOG_IS_ON(2)) {
ptxas_args.push_back("-v");
}
@@ -674,7 +695,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// Write PTX to IR dump directory, if IR dumping was requested.
if (!ir_dump_directory.empty()) {
const string ptx_outfile = tensorflow::io::JoinPath(
- ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx"));
+ ir_dump_directory, absl::StrCat(module->name(), ".ptx"));
auto status = [&] {
auto* env = tensorflow::Env::Default();
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory));
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 192359f026..11dc56a64f 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -32,9 +32,7 @@ namespace gpu {
// TODO(jlebar): Also pad dots.
class PadForTensorCores : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "pad for tensor cores";
- }
+ absl::string_view name() const override { return "pad for tensor cores"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 99e7580b82..104af48c82 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,7 +29,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-using PadForTensorCoresTest = HloVerifiedTestBase;
+class PadForTensorCoresTest : public HloVerifiedTestBase {
+ public:
+ PadForTensorCoresTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index 67e51509e4..a622e894ed 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -26,7 +26,7 @@ namespace gpu {
// padding, so that they can be lowered to cuDNN convolution.
class PadInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "pad insertion"; }
+ absl::string_view name() const override { return "pad insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 3838fee674..ca57cacb98 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
unroll_factor_(unroll_factor) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
// if (linear_index < num_elements) {
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index b82a23419d..cc7da2e73b 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
// The thread and block dimension to parallelize the loop on.
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index cca35316f0..15d1e269cc 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -27,13 +27,22 @@ namespace {
class GpuKernelTilingTest : public GpuCodegenTest {
protected:
- GpuKernelTilingTest() {
+ GpuKernelTilingTest() {}
+
+ // Most tests in this file want to skip layout assignment, but a few need it
+ // enabled.
+ HloModuleConfig ConfigWithLayoutAssignment() {
+ return GetModuleConfigForTest();
+ }
+
+ HloModuleConfig ConfigWithoutLayoutAssignment() {
+ HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
- config_.set_debug_options(debug_options);
// Disable layout_assignment to use the preassigned layouts.
- debug_options.add_xla_disable_hlo_passes("layout_assignment");
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ config.set_debug_options(debug_options);
+ return config;
}
- HloModuleConfig config_;
};
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
@@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ //
+ // We must enable layout assignment in order for this test to work correctly.
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
+ // which respects the module's entry computation layout. But if we don't run
+ // layout assignment...well, nobody else adds the copy back.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
})";
- // Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ // Check that a call to llvm.nvvm.barrier0 is not generated. As in
+ // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
+ // here.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest,
})";
// Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 9622936306..0f2d5568ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) {
HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_max_kernel_unroll_factor(2);
+ // Disable layout assignment for this test. Layout assignment does not expect
+ // fusions to be present, and so it does the wrong thing.
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
config.set_debug_options(debug_options);
const char *const kMultiOutputFusionModule = R"(
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
index bdb062837c..141f321938 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
@@ -144,16 +144,15 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn(
string ThunkSchedule::ToString() const {
string result = "Total order:\n";
for (Thunk* thunk : thunk_total_order_) {
- tensorflow::strings::StrAppend(&result, "\t",
- thunk->hlo_instruction()->ToString(), "\n");
+ absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n");
}
- tensorflow::strings::StrAppend(&result, "Dependencies:\n");
+ absl::StrAppend(&result, "Dependencies:\n");
for (const auto& entry : depends_on_) {
const Thunk* dependent = entry.first;
for (const Thunk* dependency : entry.second) {
- tensorflow::strings::StrAppend(
- &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
- dependency->hlo_instruction()->name(), "\n");
+ absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(),
+ " depends on ", dependency->hlo_instruction()->name(),
+ "\n");
}
}
return result;
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index c5f3906356..40183de96e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier;
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index 31431f115f..a2be89511b 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@@ -43,8 +43,7 @@ namespace {
// Adds a computation to the given HLO module which adds a scalar constant to
// its parameter and returns the result.
HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
- auto builder =
- HloComputation::Builder(tensorflow::strings::StrCat("add_", addend));
+ auto builder = HloComputation::Builder(absl::StrCat("add_", addend));
auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "x_value"));
auto half = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 0ca489846e..0986da65cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,15 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
@@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const {
}
string HloAliasAnalysis::ToString() const {
- string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
+ string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
for (const HloComputation* computation : module_->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
@@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
if (ordering.MayInterfere(*values[i - 1], *values[i],
dataflow_analysis())) {
VLOG(1) << "In buffer " << buffer.id() << " containing values:\n "
- << Join(values, ", ",
- [](string* out, const HloValue* value) {
- StrAppend(out, value->ToShortString());
- })
+ << absl::StrJoin(values, ", ",
+ [](string* out, const HloValue* value) {
+ StrAppend(out, value->ToShortString());
+ })
<< "\nValue " << values[i - 1]->ToShortString()
<< " may interfere with value " << values[i]->ToShortString();
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index e16413f361..6c11a073b7 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -27,15 +29,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
-
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {
@@ -59,10 +56,11 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
}
string HloBuffer::ToString() const {
- return StrCat("HloBuffer ", id_, ", values: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return absl::StrCat(
+ "HloBuffer ", id_, ", values: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) {
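Aside: the absl::StrJoin overload these conversions rely on takes a formatter callback; the lambda appends each element's rendering to *out, and the separator is inserted between renderings. A minimal, self-contained sketch (JoinSquares is hypothetical):

  #include <string>
  #include <vector>
  #include "absl/strings/str_cat.h"
  #include "absl/strings/str_join.h"

  // Renders {1, 2, 3} as "1, 4, 9": the lambda controls how each element is
  // appended, just like the HloValue->ToShortString() formatters above.
  std::string JoinSquares(const std::vector<int>& xs) {
    return absl::StrJoin(xs, ", ", [](std::string* out, int x) {
      absl::StrAppend(out, x * x);
    });
  }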
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 4c036ea1bf..cf95b112d7 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -25,6 +25,9 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -37,13 +40,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root_instruction) {
@@ -136,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) {
}
string after_param = original_name.substr(index + param_underscore.size());
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_param, &numeric_suffix)) {
return StrCat(original_name.substr(0, index + param_underscore.size()),
new_param_no);
}
@@ -805,11 +806,10 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
}
}
VLOG(3) << "Unreachable roots:"
- << tensorflow::str_util::Join(
- unreachable_roots, "\n\t",
- [](string* out, const HloInstruction* hlo) {
- tensorflow::strings::StrAppend(out, hlo->ToString());
- });
+ << absl::StrJoin(unreachable_roots, "\n\t",
+ [](string* out, const HloInstruction* hlo) {
+ absl::StrAppend(out, hlo->ToString());
+ });
return unreachable_roots;
}
@@ -980,8 +980,7 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
name_ = name_uniquer->GetUniqueName(name_);
}
-HloInstruction* HloComputation::GetInstructionWithName(
- tensorflow::StringPiece name) {
+HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
auto instructions_in_computation = instructions();
auto it = absl::c_find_if(
instructions_in_computation,
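Aside: absl::SimpleAtoi, which replaces tensorflow::strings::safe_strto64 in RenameFusionParameter above, succeeds only when the entire input parses as an integer and reports that through its bool return. A minimal sketch (ParseSuffix is hypothetical):

  #include <cstdint>
  #include "absl/strings/numbers.h"
  #include "absl/strings/string_view.h"

  // Returns -1 when the suffix is not purely numeric, mirroring the
  // "rename only on a clean parse" behavior above.
  int64_t ParseSuffix(absl::string_view s) {
    int64_t n;
    return absl::SimpleAtoi(s, &n) ? n : -1;
  }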
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index faa33f0f90..8d9b694977 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -367,7 +367,7 @@ class HloComputation {
// Returns the instruction in this computation that has name `name`. Returns
// null if there is no such instruction.
- HloInstruction* GetInstructionWithName(tensorflow::StringPiece name);
+ HloInstruction* GetInstructionWithName(absl::string_view name);
int64 unique_id() const { return unique_id_; }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 331480bd02..4557983a9c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -25,7 +25,7 @@ namespace xla {
// computation on constants.
class HloConstantFolding : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "constant_folding"; }
+ absl::string_view name() const override { return "constant_folding"; }
// Run constant folding operations on the given module. Returns whether the
// module was changed (constant expressions folded).
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index c4e27dc558..0ceb6a2968 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -16,14 +16,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
+using absl::StrCat;
using tensorflow::gtl::ArraySlice;
-using tensorflow::strings::StrCat;
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs) {
@@ -336,7 +337,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
HloComputation::Builder b{std::string(name)};
int64 param_idx = 0;
for (const Shape* param_shape : domain) {
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 5ff8946fb0..1bc6d09b45 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -177,7 +177,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name);
+ absl::string_view name);
} // namespace xla
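Aside: these StringPiece-to-string_view swaps are drop-in at call sites because absl::string_view binds to std::string and to string literals alike, without copying. An illustrative, self-contained sketch (PrintName is hypothetical):

  #include <iostream>
  #include <string>
  #include "absl/strings/string_view.h"

  void PrintName(absl::string_view name) { std::cout << name << "\n"; }

  int main() {
    std::string s = "fusion.1";
    PrintName(s);           // views the string's buffer, no copy
    PrintName("fusion.2");  // views the literal, no copy
  }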
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index 5e2b348bdd..a28c03599a 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface {
: is_layout_sensitive_(is_layout_sensitive),
only_fusion_computations_(only_fusion_computations) {}
~HloCSE() override = default;
- tensorflow::StringPiece name() const override { return "cse"; }
+ absl::string_view name() const override { return "cse"; }
// Run CSE on the given module. Returns whether the module was changed (common
// subexpressions were found and eliminated).
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 01840a56e2..1d35757b42 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -30,8 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -79,8 +78,8 @@ bool MultiDynamicSliceUseShareSameIndices(
} // namespace
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
HloDataflowAnalysis::HloDataflowAnalysis(
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
@@ -977,28 +976,22 @@ Status HloDataflowAnalysis::Verify() const {
bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
const HloInstruction* operand, const ShapeIndex& index,
const HloInstruction* user) const {
- CHECK(user->IsUserOf(operand))
- << "user: " << user->ToString() << " operand: " << operand->ToString();
- if (user->opcode() == HloOpcode::kFusion &&
- user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
- // Find fusion parameter associated with 'operand'.
- HloInstruction* fusion_param =
- user->fused_parameter(user->operand_index(operand));
- // Iterate through all users of all uses of the fusion parameter value.
- // Return false if any uses are detected, returns true otherwise.
- const HloValue& value = GetValueDefinedAt(fusion_param, index);
- return value.uses().empty();
- } else {
- // Return false if no value at 'operand' and 'index' is used at 'user'.
- for (const HloValue* value : GetValueSet(operand, index).values()) {
- for (const HloUse& use : value->uses()) {
- if (use.instruction == user) {
- return false;
+  // Return false if any value at 'operand' and 'index' is used at 'user'.
+ for (const HloValue* value : GetValueSet(operand, index).values()) {
+ for (const HloUse& use : value->uses()) {
+ if (use.instruction == user) {
+ if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ HloInstruction* fusion_param =
+ user->fused_parameter(use.operand_number);
+ const HloValue& value =
+ GetValueDefinedAt(fusion_param, use.operand_index);
+ return value.uses().empty();
}
+ return false;
}
}
}
-
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index f4abc7a7c7..a1678d4943 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -138,7 +138,8 @@ class HloDataflowAnalysis {
// Returns true if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns false otherwise.
//
- // REQUIRES: 'operand' is an operand of 'user'.
+ // 'operand' does not have to be an operand of 'user'. This can be the case
+ // with indirect uses.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user) const;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4755c4a0cf..d1a96c10f8 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
+// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
+// parameter tuple.
+TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto t0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
+ auto t1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
+ // Swap the tuple elements.
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
+
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction never uses tuple element 0, but does use element 1.
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
+ // The same holds for the parameter tuple, except that the tuple elements are
+ // swapped in 'tuple'.
+ EXPECT_TRUE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
+ EXPECT_FALSE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
+}
+
class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 4e244494d6..1fe69b1395 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -36,7 +36,7 @@ namespace xla {
class HloDCE : public HloPassInterface {
public:
~HloDCE() override {}
- tensorflow::StringPiece name() const override { return "dce"; }
+ absl::string_view name() const override { return "dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
index af904647f8..72185698c9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();
private:
- // Inserts a kDomain instruction between operand and instruction in case
- // the attribute (ie, sharding) values change between root and instruction.
- // Returns the newly inserted kDomain instruction, or nullptr if no kDomain
- // instruction was necessary.
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
- HloInstruction* root,
- HloInstruction* operand);
-
HloModule* module_;
HloDomainIsolator* isolator_;
};
-StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
- HloInstruction* instruction, HloInstruction* root,
- HloInstruction* operand) {
- HloInstruction* domain = nullptr;
- std::unique_ptr<HloInstruction> domain_instruction =
- isolator_->creator_(instruction, root, operand);
- if (domain_instruction != nullptr) {
- domain = operand->parent()->AddInstruction(std::move(domain_instruction));
- TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
- }
- return domain;
-}
-
StatusOr<bool> HloDomainIsolator::RunContext::Run() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
@@ -76,10 +55,11 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
- TF_ASSIGN_OR_RETURN(HloInstruction * domain,
- CreateDomain(instruction, root, operand));
+ HloInstruction* domain =
+ isolator_->creator_(instruction, root, operand);
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
++added_domains;
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index bb1537766c..d36631fc2f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -38,12 +38,12 @@ class HloDomainIsolator : public HloPassInterface {
// instruction differs from the attribute of the root (the second
// HloInstruction argument).
// Returns nullptr in case no domain separation is necessary.
- using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
+ using DomainCreator = std::function<HloInstruction*(
HloInstruction*, HloInstruction*, HloInstruction*)>;
explicit HloDomainIsolator(DomainCreator creator);
- tensorflow::StringPiece name() const override { return "domain_isolator"; }
+ absl::string_view name() const override { return "domain_isolator"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index f855f2a1fc..575149c8b8 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -63,7 +63,7 @@ class DomainMetadata {
// Returns the metadata type. A unique identifier which describes the real
// metadata type.
- virtual tensorflow::StringPiece Kind() const = 0;
+ virtual absl::string_view Kind() const = 0;
// Compares the metadata object with another one and returns true if the
// two match.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index c859e05f02..97bc8ef604 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface {
// instructions in it with the same attributes (ie, sharding), a normalizer
// function is tasked with applying attribute normalization to the instructions
// within such a domain.
- HloDomainRemover(tensorflow::StringPiece kind,
+ HloDomainRemover(absl::string_view kind,
std::function<Status(const DomainMetadata::Domain&,
const DomainMetadata* metadata)>
normalizer)
- : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {}
+ : kind_(kind), normalizer_(std::move(normalizer)) {}
- tensorflow::StringPiece name() const override { return "domain_remover"; }
+ absl::string_view name() const override { return "domain_remover"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 2654929bf0..79e78ee2d0 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -29,6 +29,11 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
+ public:
+ HloDomainTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
@@ -46,9 +51,8 @@ class HloDomainTest : public HloVerifiedTestBase {
// Checks whether there is a kDomain instruction in the edge between the
// instruction and the operand.
- bool HasDomainEdge(HloModule* module,
- tensorflow::StringPiece instruction_name,
- tensorflow::StringPiece operand_name) {
+ bool HasDomainEdge(HloModule* module, absl::string_view instruction_name,
+ absl::string_view operand_name) {
HloInstruction* instruction = FindInstruction(module, instruction_name);
HloInstruction* operand = FindInstruction(module, operand_name);
CHECK_NE(instruction, nullptr);
@@ -66,7 +70,7 @@ class HloDomainTest : public HloVerifiedTestBase {
return false;
}
- StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) {
+ StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
ParseAndVerifyModule(hlo_string, config);
@@ -84,7 +88,7 @@ class OpNameMetadata : public DomainMetadata {
return absl::make_unique<OpNameMetadata>(opname_);
}
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override {
const OpNameMetadata* other_ptr =
@@ -98,16 +102,16 @@ class OpNameMetadata : public DomainMetadata {
string ToString() const override { return opname_; }
- static tensorflow::StringPiece KindName() { return "opname"; }
+ static absl::string_view KindName() { return "opname"; }
private:
string opname_;
};
// Creator function for OpNameMetadata domains.
-std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
- HloInstruction* root,
- HloInstruction* operand) {
+HloInstruction* OpNameDomainCreator(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
if (instruction->metadata().op_name() == root->metadata().op_name()) {
return nullptr;
}
@@ -115,9 +119,9 @@ std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
absl::make_unique<OpNameMetadata>(root->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
+ return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata)));
}
Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
@@ -144,7 +148,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -186,7 +190,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(!isolator_changed);
}
@@ -213,7 +217,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -250,7 +254,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_FALSE(isolator_changed);
}
@@ -304,7 +308,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator sharding_isolator(CreateShardingDomain);
+ HloDomainIsolator sharding_isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
sharding_isolator.Run(module));
EXPECT_TRUE(sharding_isolator_changed);
@@ -358,7 +362,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -447,7 +451,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -506,7 +510,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
index 751fc677e2..dc514ae3e5 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
@@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() {
TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
instruction->operand_side_metadata().Kind())
<< instruction->ToString();
- kinds.insert(instruction->user_side_metadata().Kind().ToString());
+ kinds.insert(string(instruction->user_side_metadata().Kind()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 8e53cf97f8..81d6d69a8c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
- tensorflow::StringPiece name() const override { return "domain_verifier"; }
+ absl::string_view name() const override { return "domain_verifier"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 2b109225d0..44ded2c2fa 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface {
HloElementTypeConverter(PrimitiveType eliminate_type,
PrimitiveType replace_with_type);
- tensorflow::StringPiece name() const override {
- return "element_type_converter";
- }
+ absl::string_view name() const override { return "element_type_converter"; }
// Runs the pass on the module and returns whether the module was modified.
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index fb90049491..ca1c4dd0e9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 4b8e6260ac..c3af15c6a8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,7 +52,10 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest() : use_bfloat16_(GetParam()) {
+ HloEvaluatorTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
@@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
+ public:
+ HloEvaluatorPreciseReduceTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index eba80c0f19..460ae2b5ec 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::AllOf;
using ::testing::ContainsRegex;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index f8ade39e8c..59c628e945 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -26,6 +26,10 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -40,27 +44,24 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
-using ::absl::nullopt;
-using ::absl::optional;
-using ::tensorflow::Env;
-using ::tensorflow::WriteStringToFile;
-using ::tensorflow::io::JoinPath;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::StringReplace;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
namespace hlo_graph_dumper {
namespace {
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
+using tensorflow::Env;
+using tensorflow::WriteStringToFile;
+using tensorflow::io::JoinPath;
+
// Helpers for Printf and Appendf.
template <typename T>
struct PrintfConvert {
@@ -217,9 +218,8 @@ string NodeColorAttributes(ColorScheme color) {
// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
-string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
- return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
- "&gt;", /*replace_all=*/true);
+string HtmlLikeStringSanitize(absl::string_view s) {
+ return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
}
// Tries to generate a human-readable one-word description of the given
@@ -322,7 +322,7 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
public:
- HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
+ HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options, bool show_backend_config,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
@@ -457,7 +457,7 @@ labelloc = t;
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
-stylesheet="
+stylesheet=<
data:text/css,
@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
svg text {
@@ -466,7 +466,7 @@ stylesheet="
}
%s
-"
+>
)";
@@ -559,10 +559,10 @@ stylesheet="
}
}
- return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
+ return Printf(fmt, graph_label, StrJoin(edge_css_rules, "\n"));
}
-string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
+string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
@@ -854,7 +854,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
- if (tensorflow::str_util::StartsWith(constant->name(), "constant")) {
+ if (absl::StartsWith(constant->name(), "constant")) {
constant_name = constant->name();
} else {
constant_name = StrCat("constant ", constant->name());
@@ -896,7 +896,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
}
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
@@ -1084,8 +1084,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
// The HLO instruction name usually contains the opcode, e.g. "%add.42" is
// an add instruction. In this case we render just the name.
- if (tensorflow::str_util::StartsWith(instr->name(),
- HloOpcodeString(instr->opcode()))) {
+ if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
}
string extended_opcode =
@@ -1113,7 +1112,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
instr->metadata().source_line()));
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
@@ -1160,8 +1159,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
constexpr int kMaxShapeLen = 64;
if (instr_shape.length() > kMaxShapeLen) {
instr_shape = StrCat(
- tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
- "...");
+ absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
}
lines.push_back(instr_shape);
}
@@ -1178,7 +1176,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
100 * hlo_cycles_executed / total_cycles_executed));
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
// Gets the total number of array elements in the given shape. For tuples, this
@@ -1271,7 +1269,7 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
HtmlLikeStringSanitize(*computation_type)));
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
const HloInstruction* HloDotDumper::GetNodeForEdge(
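Aside: absl::StrReplaceAll, now used by HtmlLikeStringSanitize, applies all substitutions in a single pass over the input rather than chaining one replacement after another. A minimal sketch (EscapeAngles is hypothetical):

  #include <string>
  #include "absl/strings/str_replace.h"
  #include "absl/strings/string_view.h"

  // "a<b>" becomes "a&lt;b&gt;"; both replacements happen in one pass.
  std::string EscapeAngles(absl::string_view s) {
    return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
  }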
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 1d7a062c55..064c53252c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,12 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::HasSubstr;
string TestName() {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 668ed9d6c3..2bb9de686f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -24,6 +24,11 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
@@ -41,17 +46,15 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
@@ -664,8 +667,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction::CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier,
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id) {
return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_groups, barrier,
@@ -688,7 +690,7 @@ HloInstruction::CreateCrossReplicaSum(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
+ HloInstruction* token_operand, absl::string_view outfeed_config) {
return absl::make_unique<HloOutfeedInstruction>(
outfeed_shape, operand, token_operand, outfeed_config);
}
@@ -1066,7 +1068,7 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target) {
+ absl::string_view custom_call_target) {
return absl::make_unique<HloCustomCallInstruction>(shape, operands,
custom_call_target);
}
@@ -1345,7 +1347,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
// If the name ends with .suffix[0-9]+ then replace the suffix with the
// numeric value incremented.
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
clone->name_ =
StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
} else {
@@ -1817,7 +1819,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) {
string HloInstruction::SignatureString() const {
string operands =
- Join(operands_, ", ", [](string* out, HloInstruction* operand) {
+ StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
@@ -1964,7 +1966,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
- operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
// If the operand has already been deleted, put `null` in the string output.
if (operand == nullptr) {
StrAppend(out, "null ");
@@ -1984,7 +1986,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
} else if (!options.compact_operands()) {
str.push_back(PrintName(operand->name(), options));
}
- StrAppend(out, Join(str, " "));
+ StrAppend(out, StrJoin(str, " "));
});
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
@@ -2030,8 +2032,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
- extra.push_back(StrCat(
- "calls=", Join(called_computations(), ", ",
+ extra.push_back(
+ StrCat("calls=",
+ StrJoin(called_computations(), ", ",
[&](string* out, const HloComputation* computation) {
StrAppend(out,
PrintName(computation->name(), options));
@@ -2068,12 +2071,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
break;
default:
if (!called_computations().empty()) {
- extra.push_back(
- StrCat("calls=\n",
- Join(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out, computation->ToString(new_options));
- })));
+ extra.push_back(StrCat(
+ "calls=\n",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->ToString(new_options));
+ })));
}
break;
}
@@ -2084,11 +2087,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
if (!control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
- Join(control_predecessors_, ", ",
- [&](string* out, HloInstruction* pre) {
- StrAppend(out,
- PrintName(pre->name(), options));
- }),
+ StrJoin(control_predecessors_, ", ",
+ [&](string* out, HloInstruction* pre) {
+ StrAppend(out,
+ PrintName(pre->name(), options));
+ }),
"}"));
}
if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
@@ -2102,10 +2105,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
string HloInstruction::ToShortString() const {
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
- Join(operands_, ", ",
- [](string* out, HloInstruction* operand) {
- StrAppend(out, "%", operand->name());
- }),
+ StrJoin(operands_, ", ",
+ [](string* out, HloInstruction* operand) {
+ StrAppend(out, "%", operand->name());
+ }),
")");
}
@@ -2795,7 +2798,7 @@ string PaddingConfigToString(const PaddingConfig& padding) {
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
- return Join(
+ return StrJoin(
padding.dimensions(), "x",
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
StrAppend(
@@ -2819,16 +2822,15 @@ string OpMetadataToString(const OpMetadata& metadata) {
if (metadata.source_line() != 0) {
result.push_back(StrCat("source_line=", metadata.source_line()));
}
- return Join(result, " ");
+ return StrJoin(result, " ");
}
string RandomDistributionToString(const RandomDistribution& distribution) {
- return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
+ return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}
string PrecisionToString(const PrecisionConfigProto::Precision& precision) {
- return tensorflow::str_util::Lowercase(
- PrecisionConfigProto::Precision_Name(precision));
+ return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision));
}
string ConvolutionDimensionNumbersToString(
@@ -2856,8 +2858,8 @@ string ConvolutionDimensionNumbersToString(
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
- return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
- Join(output_dims, ""));
+ return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
+ StrJoin(output_dims, ""));
}
string HloInstruction::DotDimensionNumbersToString() const {
@@ -2868,19 +2870,21 @@ string HloInstruction::DotDimensionNumbersToString() const {
const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
if (!dnums.lhs_batch_dimensions().empty()) {
result.push_back(StrCat("lhs_batch_dims={",
- Join(dnums.lhs_batch_dimensions(), ","), "}"));
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
}
result.push_back(StrCat("lhs_contracting_dims={",
- Join(dnums.lhs_contracting_dimensions(), ","), "}"));
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
if (!dnums.rhs_batch_dimensions().empty()) {
result.push_back(StrCat("rhs_batch_dims={",
- Join(dnums.rhs_batch_dimensions(), ","), "}"));
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
}
result.push_back(StrCat("rhs_contracting_dims={",
- Join(dnums.rhs_contracting_dimensions(), ","), "}"));
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
- return Join(result, ", ");
+ return StrJoin(result, ", ");
}
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
@@ -2894,7 +2898,7 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
return map;
}();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
+ auto found = map->find(absl::AsciiStrToLower(name));
if (found == map->end()) {
return InvalidArgument("Unknown distribution");
}
@@ -2907,15 +2911,14 @@ string HloInstruction::PrecisionConfigToString() const {
}
return StrCat(
"operand_precision={",
- Join(precision_config_.operand_precision(), ",",
- [](string* out, int32 precision) {
- CHECK(PrecisionConfigProto::Precision_IsValid(precision))
- << precision;
- StrAppend(
- out,
- PrecisionToString(
- static_cast<PrecisionConfigProto::Precision>(precision)));
- }),
+ StrJoin(precision_config_.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfigProto::Precision_IsValid(precision))
+ << precision;
+ StrAppend(out, PrecisionToString(
+ static_cast<PrecisionConfigProto::Precision>(
+ precision)));
+ }),
"}");
}
@@ -2932,7 +2935,7 @@ StatusOr<PrecisionConfigProto::Precision> StringToPrecision(
}
return map;
}();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
+ auto found = map->find(absl::AsciiStrToLower(name));
if (found == map->end()) {
return InvalidArgument("Unknown distribution");
}
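Aside: absl::AsciiStrToLower returns a lowercased copy of its input, which is what the case-insensitive enum-name lookups above rely on. A minimal sketch (ToKey is hypothetical):

  #include <string>
  #include "absl/strings/ascii.h"
  #include "absl/strings/string_view.h"

  // Case-insensitive lookup key, as used by the maps in
  // StringToRandomDistribution / StringToPrecision.
  std::string ToKey(absl::string_view name) {
    return absl::AsciiStrToLower(name);  // "UNIFORM" -> "uniform"
  }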
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 121a9e55f6..566c1c449a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -34,6 +34,8 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -47,7 +49,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -222,7 +223,7 @@ class CanonicalNameMap {
return iter->second;
}
- string new_name = tensorflow::strings::StrCat("tmp_", index++);
+ string new_name = absl::StrCat("tmp_", index++);
canonical_name_map[old_name] = new_name;
return new_name;
}
@@ -450,8 +451,7 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier,
- const absl::optional<int64>& all_reduce_id);
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
// This op handles the communication of an Alltoall operation. On each core,
// the operands are N ops in the same shape, where N is the number of cores
@@ -493,7 +493,7 @@ class HloInstruction {
// which is a TOKEN.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
+ HloInstruction* token_operand, absl::string_view outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
@@ -706,7 +706,7 @@ class HloInstruction {
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
+ absl::string_view custom_call_target);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
@@ -1037,6 +1037,8 @@ class HloInstruction {
CHECK(has_sharding());
return *sharding_;
}
+ std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
+
// Returns the sharding applied to this operator, or default_ if none exists.
const HloSharding& sharding_or_default(const HloSharding& default_) const {
return sharding_ ? *sharding_ : default_;
@@ -1051,7 +1053,10 @@ class HloInstruction {
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
- sharding_ = absl::make_unique<HloSharding>(sharding);
+ sharding_ = std::make_shared<const HloSharding>(sharding);
+ }
+ void set_sharding(std::shared_ptr<const HloSharding> sharding) {
+ sharding_ = std::move(sharding);
}
void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
@@ -1652,7 +1657,10 @@ class HloInstruction {
bool copy_elision_allowed_ = true;
// The sharding, if one exists.
- std::unique_ptr<HloSharding> sharding_;
+ // Uses std::shared_ptr so the same sharding object can be reused across
+ // HloInstructions and other components, since HloSharding can be very
+ // large for tuples with many elements.
+ std::shared_ptr<const HloSharding> sharding_;
// Fields used by the kDomain instruction.
std::unique_ptr<DomainMetadata> operand_side_metadata_;
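// Reviewer note: a minimal sketch (not part of this patch) of what the new
// shared_ptr-based storage buys: two instructions can point at one immutable
// HloSharding instead of each holding a private copy. `a` and `b` are
// hypothetical HloInstruction pointers; the calls mirror the API added above.
//
//   std::shared_ptr<const HloSharding> s =
//       std::make_shared<const HloSharding>(HloSharding::Replicate());
//   a->set_sharding(s);  // no copy: both instructions now share
//   b->set_sharding(s);  // the same immutable sharding object.
//   CHECK_EQ(a->sharding_ptr().get(), b->sharding_ptr().get());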
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 2a99d4d7c4..a0de253eda 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -19,6 +19,10 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -29,10 +33,10 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) {
@@ -160,7 +164,7 @@ HloInstructionProto HloFftInstruction::ToProto() const {
std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("fft_type=", FftType_Name(fft_type())),
- StrCat("fft_length={", Join(fft_length(), ","), "}")};
+ StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
}
bool HloFftInstruction::IdenticalSlowPath(
@@ -320,10 +324,10 @@ std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
std::vector<string> replica_group_str;
for (const ReplicaGroup& group : replica_groups()) {
replica_group_str.push_back(
- StrCat("{", Join(group.replica_ids(), ","), "}"));
+ StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
}
result.push_back(
- StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
+ StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}"));
return result;
}
@@ -343,11 +347,11 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id)
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
+ const absl::optional<int64>& all_reduce_id)
: HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands,
replica_groups),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+ cross_replica_sum_barrier_(barrier),
all_reduce_id_(all_reduce_id) {
AppendComputation(reduce_computation);
}
@@ -430,7 +434,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const {
std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReverseInstruction::IdenticalSlowPath(
@@ -469,7 +473,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const {
std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloConcatenateInstruction::IdenticalSlowPath(
@@ -512,7 +516,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const {
std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReduceInstruction::IdenticalSlowPath(
@@ -555,7 +559,7 @@ HloInstructionProto HloSortInstruction::ToProto() const {
std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloSortInstruction::IdenticalSlowPath(
@@ -588,7 +592,7 @@ HloTransposeInstruction::HloTransposeInstruction(
Permute(dimensions, shape.dimensions()).begin()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << Join(dimensions, ", ") << "}";
+ << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -609,7 +613,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const {
std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloTransposeInstruction::IdenticalSlowPath(
@@ -648,7 +652,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const {
std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloBroadcastInstruction::IdenticalSlowPath(
@@ -709,7 +713,7 @@ bool HloMapInstruction::IsElementwiseImpl(
std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloMapInstruction::IdenticalSlowPath(
@@ -767,7 +771,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
}
- return {StrCat("slice={", Join(bounds, ", "), "}")};
+ return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
}
bool HloSliceInstruction::IdenticalSlowPath(
@@ -853,7 +857,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToString();
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
- std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
+ std::vector<string> v = absl::StrSplit(tmp, ' ');
bool first = true;
// Concatenate elements in "v" with spaces separating them, but ignoring
// empty entries.
@@ -1554,12 +1558,13 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
infeed_shape(), new_operands[0], infeed_config());
}
-HloOutfeedInstruction::HloOutfeedInstruction(
- const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config)
+HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ HloInstruction* token_operand,
+ absl::string_view outfeed_config)
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
- outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ outfeed_config_(outfeed_config) {
CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
<< "Outfeed shape " << outfeed_shape
<< " must be compatible with operand shape " << operand->shape();
@@ -1767,7 +1772,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target)
+ absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(),
custom_call_target.end()) {
@@ -1903,8 +1908,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {
- StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")};
+ return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
+ "}")};
}
bool HloDynamicSliceInstruction::IdenticalSlowPath(
@@ -1940,17 +1945,17 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
string offset_dims =
StrCat("offset_dims={",
- Join(gather_dimension_numbers_->offset_dims(), ","), "}");
- string collapsed_slice_dims =
- StrCat("collapsed_slice_dims={",
- Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims = StrCat(
+ "collapsed_slice_dims={",
+ StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
string start_index_map =
StrCat("start_index_map={",
- Join(gather_dimension_numbers_->start_index_map(), ","), "}");
+ StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
- return Join<std::initializer_list<string>>(
+ return StrJoin<std::initializer_list<string>>(
{offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
@@ -1987,7 +1992,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")};
+ StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2026,20 +2031,20 @@ HloScatterInstruction::HloScatterInstruction(
}
string HloScatterInstruction::ScatterDimensionNumbersToString() const {
- string update_window_dims =
- StrCat("update_window_dims={",
- Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string update_window_dims = StrCat(
+ "update_window_dims={",
+ StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}");
string inserted_window_dims = StrCat(
"inserted_window_dims={",
- Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
string scatter_dims_to_operand_dims = StrCat(
"scatter_dims_to_operand_dims={",
- Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
"}");
string index_vector_dim = StrCat(
"index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
- return Join<std::initializer_list<string>>(
+ return StrJoin<std::initializer_list<string>>(
{update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
index_vector_dim},
", ");
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 19e98c6fb4..efdb9e9781 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -248,8 +248,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier,
- const absl::optional<int64>& all_reduce_id);
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
// Returns the barrier config used for the CrossReplicaSum implementation of
// each backend.
@@ -908,7 +907,7 @@ class HloOutfeedInstruction : public HloInstruction {
explicit HloOutfeedInstruction(const Shape& outfeed_shape,
HloInstruction* operand,
HloInstruction* token_operand,
- tensorflow::StringPiece outfeed_config);
+ absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
@@ -1061,7 +1060,7 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
+ absl::string_view custom_call_target);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
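// Reviewer note: a standalone sketch (not part of this patch) of the
// constructor pattern changed above: absl::string_view direct-initializes a
// std::string member, so the old `cfg_(config.begin(), config.end())`
// iterator-pair dance goes away. `Outfeedish` is a hypothetical stand-in for
// HloOutfeedInstruction.
#include <cstdio>
#include <string>
#include "absl/strings/string_view.h"

class Outfeedish {
 public:
  explicit Outfeedish(absl::string_view outfeed_config)
      : outfeed_config_(outfeed_config) {}  // string_view -> string copy
  const std::string& config() const { return outfeed_config_; }

 private:
  std::string outfeed_config_;
};

int main() {
  Outfeedish o("device=0");
  std::printf("%s\n", o.config().c_str());
}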
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 2e01b090be..0e49d343d6 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -17,20 +17,20 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-
-using ::tensorflow::StringPiece;
-
namespace {
+using absl::string_view;
+
constexpr int kEOF = -1;
constexpr int kError = -2;
@@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
-tensorflow::StringPiece HloLexer::StringPieceFromPointers(
- const char* begin, const char* end) const {
+absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
+ const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
- return tensorflow::StringPiece(begin, end - begin);
+ return absl::string_view(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
@@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kAttributeName;
}
- tensorflow::StringPiece identifier =
+ absl::string_view identifier =
StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
@@ -306,8 +306,8 @@ TokKind HloLexer::LexNumberOrPattern() {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
- &decimal_val_);
+ CHECK(absl::SimpleAtod(string(token_start_, current_ptr_).c_str(),
+ &decimal_val_));
return TokKind::kDecimal;
}
@@ -339,7 +339,7 @@ TokKind HloLexer::LexNumberOrPattern() {
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
auto slice = StringPieceFromPointers(token_start_, current_ptr_);
- if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
+ if (absl::SimpleAtoi(slice, &int64_val_)) {
return TokKind::kInt;
}
LOG(ERROR) << "Failed to parse int literal: " << slice;
@@ -375,24 +375,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no_cache_.last_query = ptr;
line_no_cache_.line_no_of_query = line_no;
size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
- if (line_offset == tensorflow::StringPiece::npos) {
+ if (line_offset == absl::string_view::npos) {
line_offset = 0;
}
return {line_no, ptr - start - line_offset};
}
-tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
+absl::string_view HloLexer::GetLine(LocTy loc) const {
if (!CanDereference(loc)) {
return "LINE OUT OF RANGE";
}
size_t line_start =
StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
- const char* start = line_start == tensorflow::StringPiece::npos
+ const char* start = line_start == absl::string_view::npos
? buf_.begin()
: buf_.begin() + line_start + 1;
size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
const char* end =
- line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end;
+ line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
return StringPieceFromPointers(start, end);
}
@@ -404,10 +404,14 @@ TokKind HloLexer::LexString() {
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::StringPiece raw =
+ absl::string_view raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
- if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
+ // TODO(b/113077997): Change to absl::CUnescape once it works properly with
+ // copy-on-write std::string implementations.
+ if (!tensorflow::str_util::CUnescape( // non-absl ok
+ tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok
+ &str_val_, &error)) {
LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
return TokKind::kError;
}
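// Reviewer note: a standalone sketch (not part of this patch) of the numeric
// parsing swap in HloLexer::LexNumberOrPattern: absl::SimpleAtod and
// absl::SimpleAtoi take a string_view and report success via their bool
// return, mirroring the old safe_strtod/safe_strto64 contract.
#include <cstdint>
#include <cstdio>
#include "absl/strings/numbers.h"

int main() {
  double d = 0.0;
  int64_t i = 0;
  bool ok_d = absl::SimpleAtod("6.5e4", &d);  // true, d == 65000
  bool ok_i = absl::SimpleAtoi("-42", &i);    // true, i == -42
  bool bad = absl::SimpleAtoi("12abc", &i);   // false: trailing junk rejected
  std::printf("%d %d %d %f %lld\n", ok_d, ok_i, bad, d,
              static_cast<long long>(i));
}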
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index f9ecd9ccb9..3e2f8bcd52 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
@@ -34,7 +34,7 @@ namespace xla {
// it directly.
class HloLexer {
public:
- explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
+ explicit HloLexer(absl::string_view buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
@@ -77,7 +77,7 @@ class HloLexer {
std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const;
// Returns the whole line given the location.
- tensorflow::StringPiece GetLine(LocTy loc) const;
+ absl::string_view GetLine(LocTy loc) const;
private:
// Returns the current character. If it's neither the end of input buffer nor
@@ -89,8 +89,8 @@ class HloLexer {
// Creates a string_view with the given begin and end. Exits if begin > end,
// or if the range is outside the current buffer.
- tensorflow::StringPiece StringPieceFromPointers(const char* begin,
- const char* end) const;
+ absl::string_view StringPieceFromPointers(const char* begin,
+ const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
@@ -107,7 +107,7 @@ class HloLexer {
TokKind LexNumberOrPattern();
TokKind LexString();
- const tensorflow::StringPiece buf_;
+ const absl::string_view buf_;
const char* current_ptr_;
// Information about the current token.
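// Reviewer note: a standalone sketch (not part of this patch) of the
// pointer-range construction and npos checks that HloLexer::GetLine now
// performs with absl::string_view instead of tensorflow::StringPiece.
#include <cstdio>
#include "absl/strings/string_view.h"

int main() {
  const char buf[] = "line one\nline two\nline three";
  const char* loc = buf + 12;  // somewhere inside "line two"
  absl::string_view head(buf, loc + 1 - buf);
  size_t line_start = head.rfind('\n');
  const char* start =
      line_start == absl::string_view::npos ? buf : buf + line_start + 1;
  absl::string_view tail(loc, sizeof(buf) - 1 - (loc - buf));
  size_t line_end = tail.find('\n');
  const char* end = line_end == absl::string_view::npos ? buf + sizeof(buf) - 1
                                                        : loc + line_end;
  std::printf("%.*s\n", static_cast<int>(end - start), start);  // "line two"
}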
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 18f17b75ae..3a1dd471c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -29,17 +30,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
using Worklist = std::deque<const HloInstruction*>;
using Workset = std::unordered_set<const HloInstruction*>;
-namespace {
-
void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
if (workset->count(instruction) == 0) {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 7e4b883435..5269cad94d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -15,15 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace testing {
-using ::tensorflow::str_util::Join;
-
bool HloMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
@@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong lhs_contracting_dimensions (got {"
- << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {"
- << lhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
+ << "} want {" << lhs_contracting_dim_ << "})";
return false;
}
@@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong rhs_contracting_dimensions (got {"
- << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {"
- << rhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
+ << "} want {" << rhs_contracting_dim_ << "})";
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 0a442e77f0..9ace0d76e0 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -306,7 +306,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -316,7 +316,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
new ::xla::testing::HloShapeAndLayoutMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -329,7 +329,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
}
// Matcher for Sharding from sharding string
inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
- tensorflow::StringPiece sharding) {
+ absl::string_view sharding) {
return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
ParseSharding(sharding).ValueOrDie()));
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index d60b76d63f..78167335c8 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -24,11 +24,11 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -410,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
string error_message =
"The subcomputation to outline has multiple outputs:\n";
for (HloInstruction* output : outputs) {
- tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
+ absl::StrAppend(&error_message, output->ToString(), "\n");
}
LOG(FATAL) << error_message;
}
@@ -536,8 +536,7 @@ uint64 HloModule::RandomNew64() const {
return rng_();
}
-HloComputation* HloModule::GetComputationWithName(
- tensorflow::StringPiece name) {
+HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
auto computations_in_module = computations();
auto it = absl::c_find_if(
computations_in_module,
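// Reviewer note: a standalone sketch (not part of this patch) of the
// absl::c_find_if pattern GetComputationWithName relies on: the container
// algorithm takes the whole range, so no begin()/end() boilerplate.
#include <cstdio>
#include <string>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"

int main() {
  std::vector<std::string> names = {"entry", "fused_computation", "while_body"};
  absl::string_view wanted = "while_body";
  auto it = absl::c_find_if(
      names, [&](const std::string& n) { return n == wanted; });
  std::printf("%s\n", it == names.end() ? "not found" : it->c_str());
}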
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index d2e726a0db..cf129b835d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@@ -142,7 +142,7 @@ class HloModule {
// Returns the computation in this module that has the name `name`. Returns
// null if there is no such computation.
- HloComputation* GetComputationWithName(tensorflow::StringPiece name);
+ HloComputation* GetComputationWithName(absl::string_view name);
// Gets the number of computations in this module.
int64 computation_count() const { return computations_.size(); }
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index f9708283eb..9bfa3a5f45 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
bool ignore_layouts)
@@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout(
}
string HloModuleConfig::compilation_cache_key() const {
- string key =
- tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled());
+ string key = absl::StrCat("profiling=", hlo_profiling_enabled());
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
- StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
+ StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
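// Reviewer note: a standalone sketch (not part of this patch) of the
// StrCat/StrAppend/StrJoin combination compilation_cache_key() now uses; note
// the bool flag formats as "0"/"1", exactly as the rewritten call above
// expects.
#include <cstdio>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  bool profiling = false;
  std::vector<std::string> params = {"f32[4]", "f32[4,8]"};
  std::string key = absl::StrCat("profiling=", profiling);
  absl::StrAppend(&key, "::(", absl::StrJoin(params, ", "), ") => f32[8]");
  std::printf("%s\n", key.c_str());  // profiling=0::(f32[4], f32[4,8]) => f32[8]
}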
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 29024085c1..12ca2340a6 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -31,7 +31,7 @@ namespace xla {
class HloModuleDCE : public HloPassInterface {
public:
~HloModuleDCE() override {}
- tensorflow::StringPiece name() const override { return "hlo-module-dce"; }
+ absl::string_view name() const override { return "hlo-module-dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 1a4da388e4..b5c7681edd 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -270,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
string cyclic_instructions;
for (const auto& state : *visit_state) {
if (state.second == VisitState::kVisiting) {
- tensorflow::strings::StrAppend(&cyclic_instructions,
- state.first->ToString(), "\n");
+ absl::StrAppend(&cyclic_instructions, state.first->ToString(),
+ "\n");
}
}
// TODO(b/64305524): Improve the error message to print out the
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 6c1e015f77..8fe91c7278 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore(
}
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
+ if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
+ use.instruction)) {
+ continue;
+ }
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
VLOG(4) << "use of " << a << " (" << use << ") not before " << b
<< " is defined";
@@ -317,7 +321,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
}
}
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
@@ -388,7 +392,7 @@ string SequentialHloOrdering::ToString() const {
tensorflow::strings::Printf(" %s", instruction->name().c_str()));
}
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
std::ostream& operator<<(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 90a493d29f..df789e6222 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -17,6 +17,9 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -26,22 +29,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace {
-using ::absl::nullopt;
-using ::absl::optional;
-using ::tensorflow::StringPiece;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::Split;
-using ::tensorflow::str_util::SplitAndParseAsInts;
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
const double kF16max = 65504;
@@ -50,7 +49,7 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(StringPiece str, const HloModuleConfig& config)
+ explicit HloParser(absl::string_view str, const HloModuleConfig& config)
: lexer_(str), config_(config) {}
// Runs the parser. Returns false if an error occurred.
@@ -60,7 +59,7 @@ class HloParser {
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
- string GetError() const { return Join(error_, "\n"); }
+ string GetError() const { return StrJoin(error_, "\n"); }
// Standalone parsing utilities for various aggregate data types.
StatusOr<HloSharding> ParseShardingOnly();
@@ -253,8 +252,8 @@ class HloParser {
bool CanBeParamListToShape();
// Logs the current parsing line and the given message. Always returns false.
- bool TokenError(StringPiece msg);
- bool Error(LocTy loc, StringPiece msg);
+ bool TokenError(absl::string_view msg);
+ bool Error(LocTy loc, absl::string_view msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
@@ -293,6 +292,17 @@ class HloParser {
missing_instruction_hook_;
};
+bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
+ for (const auto& split : absl::StrSplit(s, delim)) {
+ int64 val;
+ if (!absl::SimpleAtoi(split, &val)) {
+ return false;
+ }
+ out->push_back(val);
+ }
+ return true;
+}
+
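// Reviewer note: a standalone sketch (not part of this patch) reproducing the
// SplitToInt64s helper added above so its contract is easy to see: it splits
// on the delimiter, parses every piece with absl::SimpleAtoi, and fails as a
// whole if any piece is not an integer -- the role the removed
// tensorflow::str_util::SplitAndParseAsInts used to play. On failure, *out may
// hold the pieces parsed before the bad one, as in the patch.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
  for (absl::string_view split : absl::StrSplit(s, delim)) {
    int64_t val;
    if (!absl::SimpleAtoi(split, &val)) return false;
    out->push_back(val);
  }
  return true;
}

int main() {
  std::vector<int64_t> dims;
  std::printf("%d\n", SplitToInt64s("3x2x1", 'x', &dims));  // 1; dims={3,2,1}
  std::printf("%d\n", SplitToInt64s("3x_x1", 'x', &dims));  // 0: "_" rejected
}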
// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
@@ -307,7 +317,7 @@ std::vector<ReplicaGroup> CreateReplicaGroups(
return replica_groups;
}
-bool HloParser::Error(LocTy loc, StringPiece msg) {
+bool HloParser::Error(LocTy loc, absl::string_view msg) {
auto line_col = lexer_.GetLineAndColumn(loc);
const unsigned line = line_col.first;
const unsigned col = line_col.second;
@@ -317,12 +327,12 @@ bool HloParser::Error(LocTy loc, StringPiece msg) {
error_lines.push_back(std::string(lexer_.GetLine(loc)));
error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
- error_.push_back(Join(error_lines, "\n"));
+ error_.push_back(StrJoin(error_lines, "\n"));
VLOG(1) << "Error: " << error_.back();
return false;
}
-bool HloParser::TokenError(StringPiece msg) {
+bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
@@ -1806,10 +1816,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
std::vector<tensorflow::int64> elems_seen_until_dim(
elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim);
return StrCat("[",
- Join(elems_seen_until_dim, ",",
- [](string* out, const tensorflow::int64& num_elems) {
- StrAppend(out, num_elems - 1);
- }),
+ StrJoin(elems_seen_until_dim, ",",
+ [](string* out, const tensorflow::int64& num_elems) {
+ StrAppend(out, num_elems - 1);
+ }),
"]");
};
do {
@@ -1996,7 +2006,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return Error(
index_loc,
StrCat("invalid multi-dimension index for shape with rank ", rank,
- ": [", Join(index, ", "), "]"));
+ ": [", StrJoin(index, ", "), "]"));
}
}
if (!ParseToken(TokKind::kColon,
@@ -2173,10 +2183,10 @@ bool HloParser::ParseAttributeHelper(
} else {
allowed_attrs = StrCat(
"Allowed attributes: ",
- Join(attrs, ", ",
- [&](string* out, const std::pair<string, AttrConfig>& kv) {
- StrAppend(out, kv.first);
- }));
+ StrJoin(attrs, ", ",
+ [&](string* out, const std::pair<string, AttrConfig>& kv) {
+ StrAppend(out, kv.first);
+ }));
}
return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(),
allowed_attrs.c_str()));
@@ -2489,20 +2499,24 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
string str = lexer_.GetStrVal();
- // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
+ // The str is expected to have 3 items, lhs, rhs, out, and it must look like
// lhs_rhs->out, that is, the first separator is "_" and the second is "->".
- // So we replace the "->" with "_" and then split on "_".
- str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
- /*newsub=*/"_",
- /*replace_all=*/false);
- std::vector<string> lhs_rhs_out = Split(str, "_");
- if (lhs_rhs_out.size() != 3) {
+ std::vector<string> split1 = absl::StrSplit(str, "_");
+ if (split1.size() != 2) {
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
<< str;
}
+ std::vector<string> split2 = absl::StrSplit(split1[1], "->");
+ if (split2.size() != 2) {
+ LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
+ << str;
+ }
+ absl::string_view lhs = split1[0];
+ absl::string_view rhs = split2[0];
+ absl::string_view out = split2[1];
- const tensorflow::int64 rank = lhs_rhs_out[0].length();
- if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
+ const tensorflow::int64 rank = lhs.length();
+ if (rank != rhs.length() || rank != out.length()) {
return TokenError(
"convolution lhs, rhs, and output must have the same rank");
}
@@ -2517,8 +2531,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
// lhs
{
- const string& lhs = lhs_rhs_out[0];
- if (!is_unique(lhs)) {
+ if (!is_unique(string(lhs))) {
return TokenError(
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
}
@@ -2541,8 +2554,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
// rhs
{
- const string& rhs = lhs_rhs_out[1];
- if (!is_unique(rhs)) {
+ if (!is_unique(string(rhs))) {
return TokenError(
StrCat("expects unique rhs dimension numbers, but sees ", rhs));
}
@@ -2565,8 +2577,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
// output
{
- const string& out = lhs_rhs_out[2];
- if (!is_unique(out)) {
+ if (!is_unique(string(out))) {
return TokenError(
StrCat("expects unique output dimension numbers, but sees ", out));
}
@@ -2832,7 +2843,7 @@ bool HloParser::ParseDxD(const string& name,
// 2D or higher.
if (lexer_.GetKind() == TokKind::kDxD) {
string str = lexer_.GetStrVal();
- if (!SplitAndParseAsInts(str, 'x', result)) {
+ if (!SplitToInt64s(str, 'x', result)) {
return Error(loc,
Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
}
@@ -2852,10 +2863,9 @@ bool HloParser::ParseWindowPad(
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (int i = 0; i < padding_str.size(); i++) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> low_high;
- if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
"expects padding_low and padding_high separated by '_'");
@@ -2876,10 +2886,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
}
LocTy loc = lexer_.GetLoc();
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (const auto& padding_dim_str : padding_str) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> padding_dim;
- if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
"expects padding config pattern like 'low_high_interior' or "
@@ -3162,7 +3171,7 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config) {
+ absl::string_view str, const HloModuleConfig& config) {
HloParser parser(str, config);
if (!parser.Run()) {
return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
@@ -3170,39 +3179,38 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
return parser.ConsumeHloModule();
}
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
HloModuleConfig config;
return ParseHloString(str, config);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- tensorflow::StringPiece str, tensorflow::StringPiece name) {
+ absl::string_view str, absl::string_view name) {
HloModuleConfig config;
HloParser parser(str, config);
- auto builder = absl::make_unique<HloComputation::Builder>(name.ToString());
+ auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(name.ToString(), config);
+ auto module = absl::make_unique<HloModule>(string(name), config);
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
+StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseShardingOnly();
}
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+StatusOr<Window> ParseWindow(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str) {
+ absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseConvolutionDimensionNumbersOnly();
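// Reviewer note: a sketch (not part of this patch) of the string_view-based
// entry points this file now exposes, as exercised by the parser tests below;
// the HLO text is a hypothetical module in HloModule::ToString() format.
//
//   StatusOr<std::unique_ptr<HloModule>> result = ParseHloString(R"(
//     HloModule add_module
//     ENTRY %add (x: f32[2], y: f32[2]) -> f32[2] {
//       %x = f32[2]{0} parameter(0)
//       %y = f32[2]{0} parameter(1)
//       ROOT %sum = f32[2]{0} add(%x, %y)
//     }
//   )");
//   TF_CHECK_OK(result.status());
//   std::unique_ptr<HloModule> module = result.ConsumeValueOrDie();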
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 6c184bfe9a..0c64b50481 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
@@ -32,32 +33,31 @@ namespace xla {
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, parses the string and creates an HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config);
+ absl::string_view str, const HloModuleConfig& config);
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- tensorflow::StringPiece str, tensorflow::StringPiece name = "single_op");
+ absl::string_view str, absl::string_view name = "single_op");
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, parses the string and creates an HloModule with default config.
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
// Parses the result of window_util::ToString(const Window&).
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+StatusOr<Window> ParseWindow(absl::string_view str);
// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
// "b0f_0io->b0f".
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str);
+ absl::string_view str);
// Parses a sharding from str. str is expected to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index f310b36bfb..b3d3ccda74 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -16,20 +16,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <string>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
-namespace op = ::xla::testing::opcode_matchers;
-
namespace xla {
-
namespace {
-using ::tensorflow::StringPiece;
+namespace op = ::xla::testing::opcode_matchers;
+using absl::string_view;
struct TestData {
string test_name;
@@ -1128,8 +1127,8 @@ ENTRY Computation {
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected))
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
@@ -1393,15 +1392,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
)";
- ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=00_01_10", suffix))
- .status()
- .error_message(),
- "expects dim labels pattern");
+ ExpectHasSubstr(
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
ExpectHasSubstr(
- ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=010_1100->010", suffix))
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
"must have the same rank");
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index 0cddf8fb8f..f1ad0f9b01 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -29,7 +29,7 @@ namespace xla {
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
- virtual tensorflow::StringPiece name() const = 0;
+ virtual absl::string_view name() const = 0;
// Run the pass on the given HLO module. Return whether it modified the
// module.
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index d8f1ab916b..df99e131d8 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,22 +17,22 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+
+using absl::StrAppend;
+using absl::StrCat;
+
void DumpModuleGraph(const HloModule& module, const string& message) {
hlo_graph_dumper::MaybeDumpHloModule(module, message);
VLOG(3) << "HLO " << message << ":";
@@ -68,7 +68,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
repeated_field.end());
if (!disabled_passes.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << tensorflow::str_util::Join(disabled_passes, ", ");
+ << absl::StrJoin(disabled_passes, ", ");
}
auto run_invariant_checkers = [this,
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 3bb1342aa3..1d41a4dac1 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -34,7 +34,7 @@ namespace xla {
class HloPassPipeline : public HloPassInterface {
public:
explicit HloPassPipeline(const string& name) : name_(name) {}
- tensorflow::StringPiece name() const override { return name_; }
+ absl::string_view name() const override { return name_; }
// Add a pass to the pipeline. It should be called with the arguments for the
// pass constructor:
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
index b9cca13870..c3cacd7ce6 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 04e4a29359..9cc1f5a10e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -38,17 +40,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
@@ -207,11 +206,10 @@ class InstructionList {
Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
- << tensorflow::str_util::Join(before_instructions, ", ",
- [](string* out, Item* item) {
- tensorflow::strings::StrAppend(
- out, item->instruction->name());
- })
+ << absl::StrJoin(before_instructions, ", ",
+ [](string* out, Item* item) {
+ absl::StrAppend(out, item->instruction->name());
+ })
<< "}";
// Find the minimal position number of any instruction in
@@ -394,10 +392,9 @@ class MemoryUsageTracker {
int64 unfinished_user_count;
string ToString() const {
- return tensorflow::strings::StrCat(
- "Buffer ", id, " (defined by ",
- defining_instruction->instruction->name(), ", size ", size,
- " bytes)");
+ return absl::StrCat("Buffer ", id, " (defined by ",
+ defining_instruction->instruction->name(), ", size ",
+ size, " bytes)");
}
};
@@ -741,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
}
string MemoryUsageTracker::ToString() const {
- string output = tensorflow::strings::StrCat("MemoryUsageTracker for ",
- computation_->name(), "\n");
- tensorflow::strings::StrAppend(
- &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
- memory_usage(), " bytes)");
+ string output =
+ absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
+ absl::StrAppend(&output,
+ "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
+ memory_usage(), " bytes)");
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* instruction = item->instruction;
string inprogress = item == in_progress_item_ ? " in-progress" : "";
string placed = item->placed ? " placed" : "";
- tensorflow::strings::StrAppend(&output, " ", instruction->name(),
- inprogress, placed, "\n Defines:\n");
+ absl::StrAppend(&output, " ", instruction->name(), inprogress, placed,
+ "\n Defines:\n");
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_[buffer_id];
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
- tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
- ", ", buffer.unfinished_user_count,
- " unfinished uses\n");
+ absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
+ buffer.unfinished_user_count, " unfinished uses\n");
}
- tensorflow::strings::StrAppend(&output, " Uses:\n");
+ absl::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : item->buffers_used) {
- tensorflow::strings::StrAppend(&output, " ",
- buffers_[buffer_id].ToString(), "\n");
+ absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
}
}
return output;
@@ -781,10 +776,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(defined_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique defined buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
for (const Buffer& buffer : buffers_) {
@@ -804,10 +798,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(used_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique used buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
used_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
}
for (const Buffer& buffer : buffers_) {
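// Reviewer note: a standalone sketch (not part of this patch) of the
// three-argument absl::StrJoin overload used throughout this file: the lambda
// formatter appends each element's rendering to *out, replacing the identical
// tensorflow::str_util::Join(..., lambda) calls.
#include <cstdio>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

struct Item {
  std::string name;  // hypothetical stand-in for InstructionList::Item
};

int main() {
  Item a{"add.1"}, b{"mul.2"};
  std::vector<Item*> items = {&a, &b};
  std::string joined = absl::StrJoin(
      items, ", ",
      [](std::string* out, Item* item) { absl::StrAppend(out, item->name); });
  std::printf("{%s}\n", joined.c_str());  // {add.1, mul.2}
}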
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 8f3ae9c621..7bd8a4a544 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -32,7 +32,7 @@ limitations under the License.
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
-HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
+HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 65537f07f5..cfc519063e 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -87,8 +87,7 @@ class HloRunner {
// Converts an HloModule from the given hlo textual IR string (in
// HloModule::ToString format).
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
- const tensorflow::StringPiece hlo_string,
- const DebugOptions& debug_options);
+ const absl::string_view hlo_string, const DebugOptions& debug_options);
// Reads the proto file in xla.HloProto format, creates and returns the
// HloModule.
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 27cc5361cd..393824d920 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -28,16 +28,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Class implementing a list scheduler of HLO instructions which produces a
// sequence which minimizes memory usage by preferring to schedule the node that
// frees bigger buffer and defines smaller outputs.
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 903fbbec1a..980dae07ce 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
+using absl::StrJoin;
HloSharding HloSharding::AssignDevice(int64 device_id) {
return HloSharding(device_id);
@@ -71,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
CHECK(!sharding.IsTuple()) << sharding.ToString();
- int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ int64 leaf_count = RequiredLeaves(tuple_shape);
std::vector<HloSharding> flattened_list;
- flattened_list.reserve(leaf_count);
- for (int64 i = 0; i < leaf_count; ++i) {
- flattened_list.push_back(sharding);
- }
+ flattened_list.resize(leaf_count, sharding);
return HloSharding(flattened_list);
}
@@ -92,7 +90,7 @@ string HloSharding::ToString() const {
for (const HloSharding& element : tuple_elements_) {
parts.push_back(element.ToString());
}
- return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+ return StrCat("{", absl::StrJoin(parts, ", "), "}");
}
if (replicated_) {
@@ -101,8 +99,8 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
- Join(tile_assignment_, ","), "}");
+ return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","),
+ "]", StrJoin(tile_assignment_, ","), "}");
}
}
@@ -445,7 +443,7 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
- return absl::optional<HloSharding>();
+ return absl::nullopt;
}
}
return tuple_elements_.front();
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 4c64ac60c5..be51c3f55b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -260,9 +260,9 @@ class HloSharding {
bool maximal_;
bool tuple_;
Array<int64> tile_assignment_;
- // Only non-empty when tuple_ is true, but because empty tuples are allowed
- // may also be empty even then. This is a flattened list of all the leaf
- // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+  // Only non-empty when tuple_ is true. If the tuple shape is empty, a single
+  // entry is still present for the root. This is a flattened list of all the
+  // leaf shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
std::vector<HloSharding> tuple_elements_;
};
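
The SingleTuple hunk above replaces a reserve + push_back loop with a single resize(leaf_count, sharding). A trivial sketch of the equivalence, with int standing in for HloSharding:

#include <cassert>
#include <vector>

int main() {
  const int leaf_count = 4;
  const int sharding = 7;  // stand-in for an HloSharding value

  std::vector<int> via_loop;
  via_loop.reserve(leaf_count);
  for (int i = 0; i < leaf_count; ++i) via_loop.push_back(sharding);

  std::vector<int> via_resize;
  via_resize.resize(leaf_count, sharding);  // fills new slots with `sharding`

  assert(via_loop == via_resize);  // identical contents
}

Constructing the vector directly, std::vector<HloSharding> flattened_list(leaf_count, sharding), would express the same thing in one step.
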
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 6f0353ee5f..a9b3b66934 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -118,13 +118,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
return Status::OK();
}
-std::unique_ptr<HloSharding> CloneShardingForDomain(
- const HloSharding& sharding) {
- auto single_sharding = sharding.ExtractSingleSharding();
+// For tuple shardings, if every element has the same sharding then we want to
+// treat them as single-element shardings to insert fewer domain separations,
+// as a domain can prevent some optimizations and we want to minimize that
+// from happening.
+std::shared_ptr<const HloSharding> CloneShardingForDomain(
+ std::shared_ptr<const HloSharding> sharding) {
+ auto single_sharding = sharding->ExtractSingleSharding();
if (!single_sharding) {
- return absl::make_unique<HloSharding>(sharding);
+ return sharding;
}
- return absl::make_unique<HloSharding>(*single_sharding);
+ return std::make_shared<const HloSharding>(*single_sharding);
}
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
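
CloneShardingForDomain now takes and returns std::shared_ptr<const HloSharding>, so in the common case where no tuple collapse happens the caller's object is shared rather than copied; only the collapsed single-sharding case allocates. A hedged sketch of that clone-or-share shape, using a simplified stand-in Sharding type and a direct uniformity test in place of ExtractSingleSharding:

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

struct Sharding {
  std::vector<int> elements;  // more than one element models a tuple sharding
};

std::shared_ptr<const Sharding> CloneForDomain(
    std::shared_ptr<const Sharding> sharding) {
  const auto& e = sharding->elements;
  // Collapse a tuple whose elements are all identical to a single element.
  const bool uniform =
      e.size() > 1 && std::equal(e.begin() + 1, e.end(), e.begin());
  if (!uniform) return sharding;  // share the existing object, no copy
  return std::make_shared<const Sharding>(Sharding{{e.front()}});
}

int main() {
  auto mixed = std::make_shared<const Sharding>(Sharding{{1, 2}});
  auto same = std::make_shared<const Sharding>(Sharding{{3, 3, 3}});
  std::cout << (CloneForDomain(mixed) == mixed) << "\n";       // 1: shared
  std::cout << CloneForDomain(same)->elements.size() << "\n";  // 1: collapsed
}
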
@@ -280,66 +284,18 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
return Status::OK();
}
-// Creates a kDomain instruction to be placed between instruction and operand.
-// The kDomain instruction will be created only if the sharding differ between
-// the instruction and the operand.
-std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
- HloInstruction* root,
- HloInstruction* operand) {
- const HloSharding* instruction_sharding =
- instruction->has_sharding() ? &instruction->sharding() : nullptr;
- const HloSharding* root_sharding =
- root->has_sharding() ? &root->sharding() : nullptr;
- // No need for domain if they both have no sharding.
- if (instruction_sharding == nullptr && root_sharding == nullptr) {
- return nullptr;
- }
- // No need for domain if they match.
- if (instruction_sharding != nullptr && root_sharding != nullptr &&
- ShardingMatches(*instruction_sharding, *root_sharding)) {
- return nullptr;
- }
- std::unique_ptr<HloSharding> real_instruction_sharding;
- std::unique_ptr<HloSharding> real_operand_sharding;
- if (instruction_sharding != nullptr) {
- real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
- }
- if (root_sharding != nullptr) {
- real_operand_sharding = CloneShardingForDomain(*root_sharding);
- }
- VLOG(3) << "Creating domain:";
- VLOG(3) << " Instruction: " << instruction->name();
- VLOG(3) << " Operand: " << operand->name();
- VLOG(3) << " User side sharding: "
- << (real_instruction_sharding != nullptr
- ? real_instruction_sharding->ToString()
- : "None");
- VLOG(3) << " Operand side sharding: "
- << (real_operand_sharding != nullptr
- ? real_operand_sharding->ToString()
- : "None");
-
- std::unique_ptr<DomainMetadata> operand_side_metadata =
- absl::make_unique<ShardingMetadata>(std::move(real_operand_sharding));
- std::unique_ptr<DomainMetadata> user_side_metadata =
- absl::make_unique<ShardingMetadata>(std::move(real_instruction_sharding));
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
-}
-
-StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
+StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
// If we are here, all the instructions being passed had the same sharding
// (or no sharding), by the means of the ShardingMatches() API.
// As such, no kDomain was inserted, and here we are asked to extract the
// original common sharding.
// All the instructions passed to this API are part of the same computation.
- const HloSharding* sharding = nullptr;
+ std::shared_ptr<const HloSharding> sharding;
for (HloInstruction* instruction : instructions) {
if (instruction->has_sharding()) {
if (sharding == nullptr) {
- sharding = &instruction->sharding();
+ sharding = instruction->sharding_ptr();
} else {
TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
<< "Sharding " << *sharding << " does not match the one in "
@@ -348,10 +304,10 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
}
}
if (sharding == nullptr) {
- return std::unique_ptr<HloSharding>();
+ return std::shared_ptr<const HloSharding>();
}
VLOG(4) << "Extracted sharding is " << *sharding;
- return CloneShardingForDomain(*sharding);
+ return CloneShardingForDomain(sharding);
}
} // namespace
@@ -405,7 +361,7 @@ Status ShardingMetadata::NormalizeShardingDomain(
TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
}
} else {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
ExtractOriginalCommonSharding(domain.instructions));
if (sharding != nullptr) {
VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
@@ -417,10 +373,75 @@ Status ShardingMetadata::NormalizeShardingDomain(
return Status::OK();
}
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* root,
- HloInstruction* operand) {
- return CreateDomain(instruction, root, operand);
+// Creates a kDomain instruction to be placed between instruction and operand.
+// The kDomain instruction will be created only if the shardings differ between
+// the instruction and the operand.
+HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
+ auto instruction_sharding = instruction->sharding_ptr();
+ auto root_sharding = root->sharding_ptr();
+ // No need for domain if they both have no sharding.
+ if (instruction_sharding == nullptr && root_sharding == nullptr) {
+ return nullptr;
+ }
+ // No need for domain if they match.
+ if (instruction_sharding != nullptr && root_sharding != nullptr &&
+ ShardingMatches(*instruction_sharding, *root_sharding)) {
+ return nullptr;
+ }
+
+ if (instruction_sharding != nullptr) {
+ instruction_sharding = CloneShardingForDomain(instruction_sharding);
+ }
+ if (root_sharding != nullptr) {
+ root_sharding = CloneShardingForDomain(root_sharding);
+ }
+
+ auto it = domain_cse_map_.find({operand, instruction_sharding});
+ if (it != domain_cse_map_.end()) {
+ return it->second;
+ }
+
+ VLOG(3) << "Creating domain:";
+ VLOG(3) << " Instruction: " << instruction->name();
+ VLOG(3) << " Operand: " << operand->name();
+ VLOG(3) << " User side sharding: "
+ << (instruction_sharding != nullptr ? instruction_sharding->ToString()
+ : "None");
+ VLOG(3) << " Operand side sharding: "
+ << (root_sharding != nullptr ? root_sharding->ToString() : "None");
+
+ HloInstruction* domain =
+ operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand,
+ absl::make_unique<ShardingMetadata>(root_sharding),
+ absl::make_unique<ShardingMetadata>(instruction_sharding)));
+ domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
+ domain);
+ return domain;
+}
+
+bool ShardingDomainCreator::DomainCseMapKey::operator==(
+ const ShardingDomainCreator::DomainCseMapKey& other) const {
+ if (instruction != other.instruction) {
+ return false;
+ }
+ if (sharding == nullptr && other.sharding == nullptr) {
+ return true;
+ }
+ if (sharding == nullptr || other.sharding == nullptr) {
+ return false;
+ }
+ return *sharding == *other.sharding;
+}
+
+size_t ShardingDomainCreator::DomainCseMapHasher::operator()(
+ const ShardingDomainCreator::DomainCseMapKey& key) const {
+ return tensorflow::Hash64Combine(
+ std::hash<const HloInstruction*>{}(key.instruction),
+ key.sharding ? key.sharding->Hash()
+ : static_cast<size_t>(0x297814aaad196e6dULL));
}
} // namespace xla
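
The DomainCseMapHasher above folds the sharding's Hash() into the instruction pointer's hash, substituting a fixed sentinel when the sharding is null so hashing stays consistent with operator== (which treats two null shardings as equal). An illustrative re-creation under simplified types; HashCombine below is a generic boost-style mix standing in for tensorflow::Hash64Combine, and Sharding is a stand-in type:

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

struct Sharding {
  int device;
  size_t Hash() const { return std::hash<int>{}(device); }
  bool operator==(const Sharding& o) const { return device == o.device; }
};

inline size_t HashCombine(size_t a, size_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

struct Key {
  const void* instruction;
  std::shared_ptr<const Sharding> sharding;
  bool operator==(const Key& o) const {
    if (instruction != o.instruction) return false;
    if (sharding == nullptr && o.sharding == nullptr) return true;
    if (sharding == nullptr || o.sharding == nullptr) return false;
    return *sharding == *o.sharding;  // compare by value, not by pointer
  }
};

struct KeyHasher {
  size_t operator()(const Key& k) const {
    return HashCombine(std::hash<const void*>{}(k.instruction),
                       k.sharding != nullptr
                           ? k.sharding->Hash()
                           : static_cast<size_t>(0x297814aaad196e6dULL));
  }
};

int main() {
  std::unordered_map<Key, int, KeyHasher> cse_map;
  auto s = std::make_shared<const Sharding>(Sharding{0});
  cse_map.emplace(Key{nullptr, s}, 42);  // (instruction, sharding) -> domain
}
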
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index dc258e4094..7a6b0d9abc 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -27,12 +27,12 @@ namespace xla {
// A DomainMetadata implementation that internally wraps a sharding attribute.
class ShardingMetadata : public DomainMetadata {
public:
- explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding)
+ explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding)
: sharding_(std::move(sharding)) {}
std::unique_ptr<DomainMetadata> Clone() const override;
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override;
@@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata {
const HloSharding* sharding() const { return sharding_.get(); }
- static tensorflow::StringPiece KindName() { return "sharding"; }
+ static absl::string_view KindName() { return "sharding"; }
static StatusOr<const ShardingMetadata*> ToShardingMetadata(
const DomainMetadata* metadata);
@@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata {
const DomainMetadata* metadata);
private:
- std::unique_ptr<HloSharding> sharding_;
+ std::shared_ptr<const HloSharding> sharding_;
};
-// Given an HLO graph edge between instruction and one of its operands, creates
-// a ShardingMetadata based kDomain instruction if the sharding between
-// instruction and parent changes. Returns nullptr if there is no need for a
-// domain separation.
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* root, HloInstruction* operand);
+// If the sharding between root and instruction changes then returns a
+// ShardingMetadata-based kDomain instruction that can be used to separate
+// operand and instruction.
+// Returns nullptr if there is no need for a domain separation.
+class ShardingDomainCreator {
+ public:
+ HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
+ HloInstruction* operand);
+
+ private:
+  // Maps an instruction and user-side sharding to the kDomain created for
+  // them, so that identical domains can be CSE'd.
+ struct DomainCseMapKey {
+ const HloInstruction* instruction;
+ std::shared_ptr<const HloSharding> sharding;
+
+ bool operator==(const DomainCseMapKey& other) const;
+ };
+ struct DomainCseMapHasher {
+ size_t operator()(const DomainCseMapKey& key) const;
+ };
+ std::unordered_map<DomainCseMapKey, HloInstruction*, DomainCseMapHasher>
+ domain_cse_map_;
+};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 45fc300fca..2341f8ada0 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) {
}
}
+// Tests that an empty tuple is supported.
+TEST_F(HloShardingTest, EmptySingleTuple) {
+ HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}),
+ HloSharding::AssignDevice(0));
+ EXPECT_TRUE(sharding.ExtractSingleSharding());
+}
+
TEST_F(HloShardingTest, NestedTuple) {
// nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index 2ef38821af..d1cf644f82 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -24,7 +24,7 @@ namespace xla {
// one arbitrarily to use and delete the others.
class HloSubcomputationUnification : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "subcomputation-unification";
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index b78bfa0cdf..4876533449 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -21,28 +23,25 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-using ::tensorflow::GraphDef;
-using ::tensorflow::NodeDef;
-using ::tensorflow::TensorShapeProto;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-using ::tensorflow::str_util::Join;
namespace xla {
namespace hlo_graph_dumper {
namespace {
+using absl::StrAppend;
+using absl::StrCat;
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+using tensorflow::TensorShapeProto;
+
string GetOpDefName(const HloInstruction* instruction) {
string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
- tensorflow::str_util::TitlecaseString(&name, "-");
+ tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok
name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
if (instruction->opcode() == HloOpcode::kFusion) {
string fusion_name = ToString(instruction->fusion_kind());
- StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+ StrAppend(&name, absl::string_view(fusion_name).substr(1));
}
return name;
}
@@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
} else {
layout_string = StrCat(
- "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}");
+ "{",
+ absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","),
+ "}");
}
attrs["layout"].set_s(layout_string);
}
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 14703aaf64..e0c1326177 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -30,16 +32,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
const Shape& HloPosition::shape() const {
return ShapeUtil::GetSubshape(instruction->shape(), index);
@@ -216,10 +215,11 @@ void HloValueSet::SortAndUniquifyValues() {
}
string HloValueSet::ToString() const {
- return StrCat("HloValueSet: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return StrCat(
+ "HloValueSet: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
bool HloValueSet::AssignUnionOf(
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 7acf58e252..f60c4eab42 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <set>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -122,29 +123,26 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-namespace {
-
-Status CheckIsTokenOperand(const HloInstruction* instruction,
- int64 operand_no) {
+Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
"Expected operand %lld to be token-shaped, actual shape is "
"%s:\n%s",
- operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
+ operand_no, StringifyShape(token->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
}
-Status CheckOperandAndParameter(const HloInstruction* instruction,
- int64 operand_number,
- const HloComputation* computation,
- int64 parameter_number) {
+Status ShapeVerifier::CheckOperandAndParameter(
+ const HloInstruction* instruction, int64 operand_number,
+ const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
- if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
operand->ToString().c_str(),
parameter->ToString().c_str(),
@@ -153,8 +151,6 @@ Status CheckOperandAndParameter(const HloInstruction* instruction,
return Status::OK();
}
-} // namespace
-
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -171,13 +167,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
- outfeed->operand(0)->shape())) {
+ if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed shape to be compatible with operand's shape %s, "
+ "Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
+ StringifyShape(outfeed->operand(0)->shape()).c_str(),
+ StringifyShape(outfeed->outfeed_shape()).c_str(),
outfeed->ToString().c_str());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
@@ -258,8 +253,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
return InternalError(
"Expected sort to have to have the same dimensions for the keys and "
"the values. Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ StringifyShape(sort->operand(0)->shape()).c_str(),
+ StringifyShape(sort->operand(1)->shape()).c_str());
}
return CheckVariadicShape(sort);
}
@@ -333,7 +328,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
+ for (HloInstruction* fused_param : fusion->fused_parameters()) {
+ int64 param_no = fused_param->parameter_number();
+ if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
+ return InternalError(
+ "Shape mismatch between parameter number %lld and its operand in "
+ "%s.",
+ param_no, fusion->ToString().c_str());
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
@@ -415,12 +421,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
- if (!ShapeUtil::Compatible(conditional_shape,
- ShapeUtil::MakeShape(PRED, {}))) {
+ if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
- ShapeUtil::HumanString(conditional_shape).c_str());
+ StringifyShape(conditional_shape).c_str());
}
// The shape of kWhile should match the shape of the body computation it
// calls.
@@ -598,52 +603,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
// Check if the output shape matches the expected shape.
- bool compatible;
+ //
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
- switch (instruction->opcode()) {
- case HloOpcode::kTupleSelect:
- // TupleSelect only defines the top-level buffer, which in this case is
- // the tuple, so we cannot allow mixed precision.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- case HloOpcode::kGetTupleElement:
- case HloOpcode::kTuple:
- // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
- // precision is disallowed.
- case HloOpcode::kConstant:
- case HloOpcode::kBitcast:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCall:
- case HloOpcode::kConditional:
- case HloOpcode::kConvert:
- case HloOpcode::kCustomCall:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kParameter:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kWhile:
- // The above opcodes should match the expected shapes exactly.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- default:
- if (allow_mixed_precision_) {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- } else {
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- }
- }
- if (!compatible) {
+ bool equal = [&] {
+ switch (instruction->opcode()) {
+ // The opcodes below can't have implicit layout conversions, nor can they
+ // implicitly transform f32 -> bf16. Fundamentally these are either
+ // reinterpreting existing data (e.g. kBitcast) or shuffling data around
+ // without modifying it (e.g. kGetTupleElement, kTupleSelect).
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return ShapesSame(instruction->shape(), inferred_shape);
+
+ // We allow arbitrary layout and f32->bf16 transformations on all other
+ // instructions, although this may be made more strict pending discussion
+ // in b/112709536.
+ default:
+ if (allow_mixed_precision_) {
+ return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
+ inferred_shape);
+ } else {
+ return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ }();
+ if (!equal) {
return InternalError(
- "Expected instruction to have shape compatible with %s, actual "
+ "Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(inferred_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
+ StringifyShape(inferred_shape).c_str(),
+ StringifyShape(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
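
The CheckShape rewrite above replaces a bool assigned in every switch branch with an immediately invoked lambda, so the result is computed and named in one expression and can never be read uninitialized. A minimal sketch of the pattern; the opcodes and shape checks here are stand-ins, not the real XLA definitions:

#include <iostream>

enum class Opcode { kBitcast, kTuple, kAdd };

bool ShapesSame(int a, int b) { return a == b; }
bool CompatibleIgnoringFpPrecision(int a, int b) { return a / 10 == b / 10; }

bool CheckShape(Opcode op, int actual, int inferred) {
  const bool equal = [&] {
    switch (op) {
      case Opcode::kBitcast:
      case Opcode::kTuple:
        return ShapesSame(actual, inferred);  // must match exactly
      default:
        return CompatibleIgnoringFpPrecision(actual, inferred);
    }
  }();  // the trailing () runs the lambda immediately
  return equal;
}

int main() {
  std::cout << CheckShape(Opcode::kAdd, 31, 32) << "\n";      // 1: compatible
  std::cout << CheckShape(Opcode::kBitcast, 31, 32) << "\n";  // 0: must match
}

Every switch branch must return, so the compiler, rather than the reviewer, verifies that no opcode leaves the flag unset.
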
@@ -688,10 +692,10 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- return tensorflow::str_util::Join(
- computations, ",", [](string* s, const HloComputation* computation) {
- s->append(computation->name());
- });
+ return absl::StrJoin(computations, ",",
+ [](string* s, const HloComputation* computation) {
+ s->append(computation->name());
+ });
}
// Verifies various invariants about the structure of the HLO:
@@ -827,7 +831,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
// Fused parameter instructions must be numbered contiguously and match up
- // (shapes compatible) with their respective operand.
+ // (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
@@ -848,13 +852,6 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
param_no, fusion->ToString().c_str());
}
parameter_numbers[param_no] = true;
- if (!ShapeUtil::Compatible(fused_param->shape(),
- fusion->operand(param_no)->shape())) {
- return InternalError(
- "Shape mismatch between parameter number %lld and its operand in "
- "%s.",
- param_no, fusion->ToString().c_str());
- }
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
@@ -916,7 +913,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
- "Found non-compatible shapes for instruction %s.\n"
+ "Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
HloOpcodeString(instruction->opcode()).c_str(),
ShapeUtil::HumanString(out_shape).c_str(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 523bf4d70c..b6093d667c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -28,9 +28,9 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
- explicit ShapeVerifier() : allow_mixed_precision_(false) {}
- explicit ShapeVerifier(bool allow_mixed_precision)
- : allow_mixed_precision_(allow_mixed_precision) {}
+ explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
@@ -106,13 +106,42 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
- // Return true if the shapes of the two operands have the same element type,
- // and the result shape either has the same element type as the operand
- // shapes or mixed precision is allowed and the result shape and the operand
- // shapes have floating point element types.
+ // Helpers that switch on layout_sensitive_.
+ bool ShapesSame(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::Equal(a, b)
+ : ShapeUtil::Compatible(a, b);
+ }
+ bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
+ : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
+ }
+ string StringifyShape(const Shape& s) {
+ return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
+ : ShapeUtil::HumanString(s);
+ }
+
+ // Checks that the given operand of the given instruction is of type TOKEN.
+ Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Checks that the shape of the given operand of the given instruction matches
+ // the given parameter of the given computation.
+ Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number);
+
+ // Returns true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand shapes
+ // or mixed precision is allowed and the result shape and the operand shapes
+ // have floating point element types.
bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
const Shape& result_shape);
+ // If the verifier is layout-sensitive, shapes must be equal to what's
+ // expected. Otherwise, the shapes must simply be compatible.
+ bool layout_sensitive_;
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
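
The new layout_sensitive_ flag selects between exact (layout-aware) and compatible (layout-agnostic) comparisons in one place, which is what the ShapesSame/StringifyShape helpers above encode. A compact sketch of that toggle, with a stand-in Shape type and free functions in place of ShapeUtil:

#include <iostream>
#include <string>

struct Shape {
  std::string dims;
  std::string layout;
};

bool Equal(const Shape& a, const Shape& b) {  // layout-sensitive
  return a.dims == b.dims && a.layout == b.layout;
}
bool Compatible(const Shape& a, const Shape& b) {  // layout-agnostic
  return a.dims == b.dims;
}

class Verifier {
 public:
  explicit Verifier(bool layout_sensitive)
      : layout_sensitive_(layout_sensitive) {}
  bool ShapesSame(const Shape& a, const Shape& b) const {
    return layout_sensitive_ ? Equal(a, b) : Compatible(a, b);
  }

 private:
  bool layout_sensitive_;
};

int main() {
  Shape a{"f32[2,3]", "{1,0}"}, b{"f32[2,3]", "{0,1}"};
  std::cout << Verifier(false).ShapesSame(a, b) << "\n";  // 1: compatible
  std::cout << Verifier(true).ShapesSame(a, b) << "\n";   // 0: layouts differ
}
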
@@ -125,14 +154,10 @@ class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- // Uses standard shape inference.
- explicit HloVerifier()
- : shape_verifier_factory_(
- [] { return absl::make_unique<ShapeVerifier>(false); }) {}
-
- explicit HloVerifier(bool allow_mixed_precision)
- : shape_verifier_factory_([allow_mixed_precision] {
- return absl::make_unique<ShapeVerifier>(allow_mixed_precision);
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
+ return absl::make_unique<ShapeVerifier>(layout_sensitive,
+ allow_mixed_precision);
}) {}
// Uses custom shape verification.
@@ -140,10 +165,9 @@ class HloVerifier : public HloPassInterface {
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
~HloVerifier() override = default;
- tensorflow::StringPiece name() const override { return "verifier"; }
+ absl::string_view name() const override { return "verifier"; }
- // Note: always returns false (no instructions are ever modified by this
- // pass).
+ // Never returns true; no instructions are ever modified by this pass.
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index d764964f3c..70b741353d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -37,13 +37,15 @@ using ::testing::HasSubstr;
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
public:
HloVerifierTestAllowMixedPrecision()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
TEST_F(HloVerifierTest, NullInstructionParent) {
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index bb5b40a8a8..581b3ce1e0 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -14,20 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/metric_table_report.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
+using absl::StrAppend;
+using absl::StrCat;
using tensorflow::strings::Appendf;
using tensorflow::strings::HumanReadableElapsedTime;
using tensorflow::strings::HumanReadableNumBytes;
using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
string HumanReadableProfileBuilder::ToString() const {
string s;
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
index 6f56c3aa82..b99624460e 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -29,7 +29,7 @@ namespace xla {
// computation, suitable for consumption by humans.
class HumanReadableProfileBuilder {
public:
- explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name,
+ explicit HumanReadableProfileBuilder(absl::string_view computation_name,
int64 total_cycles,
double clock_rate_ghz)
: computation_name_(std::string(computation_name)),
@@ -43,9 +43,8 @@ class HumanReadableProfileBuilder {
// Adds an operation to the profile. If you don't know the number of
// floating-point ops or bytes touched by the op, or if you don't know how
// fast it would run optimally, pass -1 for that param.
- void AddOp(tensorflow::StringPiece op_name,
- tensorflow::StringPiece short_name,
- tensorflow::StringPiece category, int64 cycles, int64 flop_count,
+ void AddOp(absl::string_view op_name, absl::string_view short_name,
+ absl::string_view category, int64 cycles, int64 flop_count,
int64 transcendental_count, int64 bytes_accessed,
float optimal_seconds) {
op_infos_.push_back({std::string(op_name), std::string(short_name),
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index aa325dc8a3..85bb4a8b24 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface {
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "implicit-broadcast-remover";
}
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index f85d31d522..df88587492 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
+ public:
+ ImplicitBroadcastRemoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 256c8e5573..43ef30d1eb 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -17,12 +17,13 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
@@ -33,32 +34,30 @@ using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
+using absl::StrJoin;
using tensorflow::gtl::ArraySlice;
-using tensorflow::str_util::Join;
} // namespace
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
auto* unknown_tensor = root->as<UnknownArray>();
- return tensorflow::strings::StrCat("%",
- unknown_tensor->instruction().name());
+ return absl::StrCat("%", unknown_tensor->instruction().name());
}
case Array::kConstant: {
if (print_constants) {
string contents = root->as<ConstantArray>()->literal()->ToString();
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
- ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ " ", contents, ")");
}
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ ")");
}
case Array::kReshaped: {
ReshapedArray* reshaped_array = root->as<ReshapedArray>();
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(reshape ", ToString(reshaped_array->operand(), print_constants),
" to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
}
@@ -69,11 +68,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
string name = root->kind() == Array::kScalarIndexedConstant
? "scalar-indexed-const"
: "scalar-indexed";
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(", name, " ", ToString(indexed_array->source(), print_constants),
" ", ToString(indexed_array->indices(), print_constants), " ",
indexed_array->source_dim(), "->[",
- Join(indexed_array->output_dims(), ","), "])");
+ StrJoin(indexed_array->output_dims(), ","), "])");
}
}
}
@@ -379,8 +378,8 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
CHECK_NE(candidate_operand_dim, 0)
<< "result_dim = " << result_dim
<< ", result_subarray_size = " << result_subarray_size
- << ", result_shape = [" << Join(result_shape, ",") << "]"
- << ", operand_shape = [" << Join(operand_shape, ",") << "]";
+ << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
+ << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";
if (candidate_operand_dim != -1 &&
result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
@@ -396,12 +395,13 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
std::vector<string> result_strings;
absl::c_transform(result, std::back_inserter(result_strings),
[](ReshapePassthroughDimPair value) {
- return tensorflow::strings::StrCat(
- value.result_dim, "->", value.operand_dim);
+ return absl::StrCat(value.result_dim, "->",
+ value.operand_dim);
});
- VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
- << Join(result_shape, ",") << "] passthrough indices are ["
- << Join(result_strings, ",") << "] (legend: `result`->`operand`)";
+ VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
+ << StrJoin(result_shape, ",") << "] passthrough indices are ["
+ << StrJoin(result_strings, ",")
+ << "] (legend: `result`->`operand`)";
}
DCHECK(absl::c_is_sorted(
@@ -443,7 +443,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
ArraySlice<int64> result_shape,
int64 source_passthrough_dim) {
VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
- << Join(operand_shape, ",") << "], [" << Join(result_shape, ",")
+ << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
<< "], " << source_passthrough_dim << ")";
int64 indexed_source_subarray_size =
@@ -755,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
if (source_dim_for_new_scalar_indexed_node == -1) {
VLOG(3) << "Could not compute the source dim for the new scalar indexed "
"node: scalar_indexed_source_shape = ["
- << Join(scalar_indexed_source_shape.dimensions(), ",")
+ << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
<< "] and new_scalar_indexed_source_shape = ["
- << Join(new_scalar_indexed_source_shape, ",") << "]";
+ << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
return nullptr;
}
@@ -997,8 +997,7 @@ absl::optional<int64> GetOnlyNonContractingNonBatchDim(
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
- tensorflow::StringPiece tag,
- Analysis::ScalarIndexedConstantArray* indexed_array,
+ absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
@@ -1135,7 +1134,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
return nullptr;
}
-tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
+absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 675eb31d26..3fa7d749e1 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -371,7 +371,7 @@ class IndexedArrayAnalysis {
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 97052edf7d..c34c32f7d3 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
+ public:
+ IndexedArrayAnalysisTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
@@ -634,9 +639,9 @@ ENTRY main {
AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant f32[3,4] f32[3,4] {
- { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
- { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
- { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+ { 0.761594, 0.964028, 0.995055, 0.999329 },
+ { 0.761594, 0.995055, 0.964028, 0.999329 },
+ { 0.999329, 0.995055, 0.964028, 0.761594 }
}) %indices 0->[0]))");
}
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index a523811f6c..efa8ed3abc 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -27,7 +27,7 @@ namespace xla {
class Inliner : public HloPassInterface {
public:
~Inliner() override = default;
- tensorflow::StringPiece name() const override { return "inline"; }
+ absl::string_view name() const override { return "inline"; }
// Run inlining on the given computation. Returns whether the computation was
// changed.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index f73ca9adf7..8489c3d9ad 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface {
bool may_duplicate = true)
: is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
~InstructionFusion() override = default;
- tensorflow::StringPiece name() const override { return "fusion"; }
+ absl::string_view name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index c75bffc63d..5741864282 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include <tuple>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -49,20 +51,12 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
-// For now moving only one API here, but we should have a single top level
-// anonymous namespace, instead of three or four spread all over this file.
-namespace {
-
-} // namespace
-
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
@@ -368,31 +362,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const {
string LayoutConstraints::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
- computation_->name(), ":\n");
+ absl::StrAppend(&output, "LayoutConstraints for computation ",
+ computation_->name(), ":\n");
for (auto* instruction : computation_->MakeInstructionPostOrder()) {
- tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(),
- "\n");
+ absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
for (int64 i = 0; i < instruction->operand_count(); ++i) {
if (OperandLayout(instruction, i) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " operand (", i,
- "): ", OperandLayout(instruction, i)->ToString(), "\n");
+ absl::StrAppend(&output, " operand (", i,
+ "): ", OperandLayout(instruction, i)->ToString(), "\n");
}
}
for (const LogicalBuffer* buffer :
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
if (BufferLayout(*buffer) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " ", buffer->ToString(), " : ",
- LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
+ absl::StrAppend(&output, " ", buffer->ToString(), " : ",
+ LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
}
}
}
if (ResultLayout() != nullptr) {
- tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(),
- "\n");
+ absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
}
return output;
}
@@ -909,7 +899,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str(),
+ absl::StrJoin(index, ",").c_str(),
buffer->ToString().c_str(),
ShapeUtil::HumanStringWithLayout(instruction_subshape)
.c_str(),
@@ -1400,8 +1390,8 @@ StatusOr<Layout> InferArrayLayout(
return FailedPrecondition(
"Array at index {%s} in instruction %s aliases buffers %s "
"and %s which have different layouts",
- tensorflow::str_util::Join(index, ",").c_str(),
- instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
+ absl::StrJoin(index, ",").c_str(), instruction->name().c_str(),
+ source_buffers[0]->ToString().c_str(),
source_buffer->ToString().c_str());
}
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index f9e8dbea2f..3e000ec2df 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -297,7 +297,7 @@ class LayoutAssignment : public HloPassInterface {
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
- tensorflow::StringPiece name() const override { return "layout-assignment"; }
+ absl::string_view name() const override { return "layout-assignment"; }
// Assign layouts to the given module. Returns whether the module was changed
// (any layouts were changed).
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index a16fa75e30..6d05fa5fe2 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase {
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
- std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) {
+ std::vector<int64> LayoutOf(HloModule* module, absl::string_view name) {
auto minor_to_major =
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 539a9522c1..fc3289f30d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -38,6 +38,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -69,6 +70,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
@@ -89,6 +91,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -104,6 +107,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -192,6 +196,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter",
"//tensorflow/compiler/xla/service/gpu:partition_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm//:core",
],
@@ -219,7 +224,7 @@ cc_library(
deps = [
":llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -230,6 +235,7 @@ cc_library(
hdrs = ["buffer_assignment_util.h"],
deps = [
"//tensorflow/compiler/xla/service:buffer_assignment",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index fe9eab93aa..8d9fa99d82 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace llvm_ir {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
index 4eb5d9fb47..bdce4a171b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+#include "absl/strings/str_cat.h"
namespace xla {
namespace llvm_ir {
@@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName(
c = '_';
}
}
- return tensorflow::strings::StrCat("buffer_for_", instr_name);
+ return absl::StrCat("buffer_for_", instr_name);
}
const Literal& LiteralForConstantAllocation(
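A minimal standalone sketch of the concatenation migration this hunk applies, tensorflow::strings::StrCat to absl::StrCat; the instruction name below is hypothetical.

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"

int main() {
  std::string instr_name = "constant_42";
  // absl::StrCat concatenates its arguments into one std::string,
  // mirroring the rewritten ConstantBufferAllocationToGlobalName call.
  std::string global_name = absl::StrCat("buffer_for_", instr_name);
  std::cout << global_name << "\n";  // buffer_for_constant_42
}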
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index 27fbb11e2e..ad350613dd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& update_shape, const ElementGenerator& start_indices_generator,
bool is_signed, ElementGenerator update_array_generator,
const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
- tensorflow::StringPiece name, llvm::IRBuilder<>* b) {
+ absl::string_view name, llvm::IRBuilder<>* b) {
const Shape& output_shape = output_array.GetShape();
// Read start indices from start_indices_generator.
@@ -101,8 +101,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
Status EmitDynamicUpdateSliceInPlace(
tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b) {
+ const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) {
VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
// No need to use operand_arrays[0], the input array of the
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index 3502577d23..e1631a62ae 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -65,8 +65,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace(
// modify the input/output buffer without touching any of the other elements.
Status EmitDynamicUpdateSliceInPlace(
tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+ const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b);
// Given a loop-fusion node whose root is a dynamic-update-slice op whose
// array-to-be-updated and output share the same buffer slice, emits
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 2b6caee6aa..6971220022 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -342,9 +342,9 @@ llvm::Value* IrArray::Index::Linearize(
return logical_linear_index;
}
-llvm::Value* IrArray::EmitArrayElementAddress(
- const IrArray::Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
+ llvm::IRBuilder<>* b,
+ absl::string_view name) const {
if (ShapeUtil::IsScalar(*shape_)) {
// Special handling of scalars: a scalar pretends to have the same value for
// every index, thus effectively implementing broadcasting of its value
@@ -402,7 +402,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+ absl::string_view name) const {
llvm::Value* element_address = EmitArrayElementAddress(index, b, name);
llvm::LoadInst* load = b->CreateLoad(element_address);
AnnotateLoadStoreInstructionWithMetadata(load);
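A minimal standalone sketch of the parameter migration in ir_array.cc, tensorflow::StringPiece to absl::string_view: both are cheap non-owning views passed by value, so literals and std::strings bind without copying. The function below is hypothetical and only mirrors signatures such as EmitReadArrayElement(..., absl::string_view name = "").

#include <iostream>
#include <string>

#include "absl/strings/string_view.h"

void EmitWithName(absl::string_view name = "") {
  std::cout << "emitting: " << name << "\n";
}

int main() {
  EmitWithName();              // the default empty name still works
  EmitWithName("dot.update");  // a literal binds without allocating
  std::string s = "loop_body";
  EmitWithName(s);             // std::string converts implicitly
}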
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index cbfd2e7012..e913c109b3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -20,12 +20,12 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -241,7 +241,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Attach metadata this IrArray instance knows about to "instruction".
void AnnotateLoadStoreInstructionWithMetadata(
@@ -255,7 +255,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Emit IR to write the given value to the array element at the given index.
void EmitWriteArrayElement(const Index& index, llvm::Value* value,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index b79567369a..bd0139f85b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace xla {
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value*, bool)>& for_body_generator) {
return If(b_->CreateICmpSLT(start, end), [&]() -> Status {
@@ -30,7 +30,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value*, llvm::Value*)>&
for_body_generator) {
@@ -56,7 +56,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::If(
- tensorflow::StringPiece name, llvm::Value* condition,
+ absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
@@ -70,7 +70,7 @@ Status KernelSupportLibrary::If(
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name,
+ absl::string_view kernel_name,
KernelSupportLibrary::ArgumentVector arguments,
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
kernel_body_generator) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index c5354a8c42..b152cf9275 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
@@ -49,13 +49,13 @@ class KernelSupportLibrary {
// `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
// }
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>& for_body_generator);
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
@@ -67,7 +67,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ Status For(absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>&
for_body_generator) {
@@ -77,7 +77,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
@@ -99,13 +99,13 @@ class KernelSupportLibrary {
// for (i64 i = `start`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/i,
// /*is_first_iteration=*/(i != `start`))`;
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
for_body_generator);
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, llvm::Value* step,
bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
@@ -119,7 +119,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
int64 step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -129,7 +129,7 @@ class KernelSupportLibrary {
peel_first_iteration, for_body_generator);
}
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, int64 step, bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -140,7 +140,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, step,
@@ -151,7 +151,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end, step,
@@ -162,8 +162,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
/*peel_first_iteration=*/false,
@@ -173,8 +172,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end,
llvm::ConstantInt::get(start->getType(), step),
@@ -182,7 +180,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -190,7 +188,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -203,7 +201,7 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- Status If(tensorflow::StringPiece name, llvm::Value* condition,
+ Status If(absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator =
[]() -> Status { return Status::OK(); });
@@ -222,7 +220,7 @@ class KernelSupportLibrary {
IfReturnVoid("", condition, true_block_generator, false_block_generator);
}
- void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition,
+ void IfReturnVoid(absl::string_view name, llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {
}) {
@@ -259,13 +257,13 @@ class KernelSupportLibrary {
// Currently we only support at most one nullptr value in `arguments`.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, ArgumentVector arguments,
+ absl::string_view kernel_name, ArgumentVector arguments,
const std::function<void(ArgumentVector)>& kernel_body_generator);
// Thin wrappers around the more general EmitAndCallOutlinedKernel above.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
kernel_body_generator) {
@@ -278,7 +276,7 @@ class KernelSupportLibrary {
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2, llvm::Value* arg3,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
llvm::Value*)>& kernel_body_generator) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index ba7f94834c..978fa5b453 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -25,14 +26,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace llvm_ir {
-ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index,
llvm::Value* step, UnrollMode unroll_mode,
bool prevent_vectorization)
@@ -46,9 +46,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
- llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
- UnrollMode unroll_mode, bool prevent_vectorization) {
+ absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
+ bool prevent_vectorization) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
end_index, step, unroll_mode,
prevent_vectorization));
@@ -168,16 +168,16 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
return result;
}
-string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
+string ForLoop::GetQualifiedName(absl::string_view name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}
-llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
+llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
llvm::IRBuilder<>* b) {
return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode,
@@ -186,12 +186,9 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
unroll_mode, prevent_vectorization);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- llvm::Value* stride,
- UnrollMode unroll_mode,
- bool prevent_vectorization) {
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
+ absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
@@ -216,7 +213,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -227,7 +224,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -238,7 +235,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
}
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix) {
+ absl::string_view suffix) {
std::vector<int64> dimensions(ShapeUtil::Rank(shape));
std::iota(dimensions.begin(), dimensions.end(), 0);
return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
@@ -246,14 +243,14 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix) {
+ absl::string_view suffix) {
llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
/*start_index=*/0,
/*end_index=*/shape.dimensions(dimension),
/*suffix=*/
- llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension)));
+ llvm_ir::IrName(suffix, absl::StrCat(dimension)));
index[dimension] = loop->GetIndVarValue();
}
return index;
@@ -261,7 +258,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
IrArray::Index ForLoopNest::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix) {
+ absl::string_view name_suffix) {
// Prepares the dimension list we will use to emit the loop nest. Outermost
// loops are added first. Add loops in major-to-minor order, and skip the
// 'dimension_to_skip' dimension.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index a4fed5c8dc..62aa15fe2d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -19,15 +19,15 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -78,7 +78,7 @@ class ForLoop {
// `unroll_mode` specifies the desired LLVM unrolling behavior for generated
// loop.
static std::unique_ptr<ForLoop> EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
+ absl::string_view prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -133,19 +133,18 @@ class ForLoop {
// Allow ForLoopNest to call this private constructor.
friend class ForLoopNest;
- ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
UnrollMode unroll_mode, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* b);
- llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+ llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b);
// Creates a name for an LLVM construct, appending prefix_ and suffix_, if
// they are set.
- string GetQualifiedName(tensorflow::StringPiece name);
+ string GetQualifiedName(absl::string_view name);
// Return a list of metadata nodes that should be associated with the
// llvm::Loop for this `ForLoop`.
@@ -182,7 +181,7 @@ class ForLoopNest {
SetIndexType(index_ty);
}
- ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b,
+ ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b,
llvm::Type* index_ty = nullptr)
: name_(std::string(name)),
outer_loop_preheader_bb_(nullptr),
@@ -197,14 +196,14 @@ class ForLoopNest {
// been added then emit loop inside the body of the last added loop.
// unroll_mode is used to emit metadata that controls LLVM unrolling.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* stride,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -213,13 +212,13 @@ class ForLoopNest {
// end index are constant.
std::unique_ptr<ForLoop> AddLoop(
int64 start_index, int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- int64 start_index, int64 end_index, tensorflow::StringPiece suffix,
+ int64 start_index, int64 end_index, absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -234,8 +233,7 @@ class ForLoopNest {
// within the shape. One possible order for that sequence would be:
//
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
- IrArray::Index AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix);
+ IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix);
// Add a loop for each dimension in "dimensions". "suffix" is the
// name suffix of the indvar and basic blocks in this new loop nest.
@@ -245,7 +243,7 @@ class ForLoopNest {
// dimension that is not in "dimensions".
IrArray::Index AddLoopsForShapeOnDimensions(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix);
+ absl::string_view suffix);
// Emits a series of nested loops for iterating over an operand array. Loops
// are constructed in major to minor dimension layout order. No loop is
@@ -256,7 +254,7 @@ class ForLoopNest {
// basic blocks) constructed by this method.
IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array,
int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix);
+ absl::string_view name_suffix);
// Convenience methods which return particular basic blocks of the outermost
// or innermost loops. These methods return nullptr if no loops have been
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index e6126881af..f0db2a3761 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/MDBuilder.h"
@@ -34,8 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -61,7 +61,7 @@ string AsString(const std::string& str) {
return string(str.data(), str.length());
}
-llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
+llvm::StringRef AsStringRef(absl::string_view str) {
return llvm::StringRef(str.data(), str.size());
}
@@ -262,15 +262,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment) {
return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment) {
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment) {
llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP();
llvm::Function* function = b->GetInsertBlock()->getParent();
b->SetInsertPoint(&function->getEntryBlock(),
@@ -285,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
}
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b) {
return llvm::BasicBlock::Create(
/*Context=*/b->getContext(),
@@ -294,27 +296,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
/*InsertBefore*/ insert_before);
}
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else) {
llvm_ir::LlvmIfData if_data;
if_data.if_block = b->GetInsertBlock();
if_data.true_block =
- CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b);
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
if_data.false_block =
- emit_else ? CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-false"), b)
+ emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
: nullptr;
// Add a terminator to the if block, if necessary.
if (if_data.if_block->getTerminator() == nullptr) {
b->SetInsertPoint(if_data.if_block);
- if_data.after_block = CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-after"), b);
+ if_data.after_block =
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
b->CreateBr(if_data.after_block);
} else {
if_data.after_block = if_data.if_block->splitBasicBlock(
- b->GetInsertPoint(),
- AsStringRef(tensorflow::strings::StrCat(name, "-after")));
+ b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after")));
}
// Our basic block should now end with an unconditional branch. Remove it;
@@ -413,14 +413,14 @@ string IrName(string a) {
return a;
}
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) {
+string IrName(absl::string_view a, absl::string_view b) {
if (!a.empty() && !b.empty()) {
- return IrName(tensorflow::strings::StrCat(a, ".", b));
+ return IrName(absl::StrCat(a, ".", b));
}
- return IrName(tensorflow::strings::StrCat(a, b));
+ return IrName(absl::StrCat(a, b));
}
-string IrName(const HloInstruction* a, tensorflow::StringPiece b) {
+string IrName(const HloInstruction* a, absl::string_view b) {
return IrName(a->name(), b);
}
@@ -556,7 +556,7 @@ std::map<int, llvm::MDNode*> MergeMetadata(
return result;
}
-static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) {
+static string GetProcessUniqueIrFileName(absl::string_view prefix) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-");
@@ -584,18 +584,16 @@ Status DumpIRToDirectory(const string& directory_name,
// XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
// dumped from the same process in such cases.
string unique_and_safe_file_name = GetProcessUniqueIrFileName(
- tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
- optimized ? "with" : "no", "-opt"));
+ absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
+ optimized ? "with" : "no", "-opt"));
string ir_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, ".ll"));
// For some models the embedded constants can be huge, so also dump the module
// with the constants stripped to get IR that is easier to manipulate.
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll"));
TF_RETURN_IF_ERROR(CreateAndWriteStringToFile(
directory_name, ir_file_name, DumpModuleToString(llvm_module)));
@@ -607,8 +605,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module) {
+ absl::string_view name, llvm::Module* module) {
llvm::Function* function =
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
function->setCallingConv(llvm::CallingConv::C);
@@ -638,7 +635,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
fake_argv_storage.push_back("");
for (const auto& it : options) {
// Skip options the XLA backend itself consumes.
- if (!tensorflow::str_util::StartsWith(it.first, "xla_")) {
+ if (!absl::StartsWith(it.first, "xla_")) {
if (it.second.empty()) {
fake_argv_storage.push_back(it.first);
} else {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 0958398534..dde50e19d1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
@@ -47,11 +47,11 @@ namespace llvm_ir {
// Convert a std::string (used by LLVM's interfaces) to string.
string AsString(const std::string& str);
-// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both
-// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a
+// Convert an absl::string_view to a llvm::StringRef. Note: both
+// absl::string_view and llvm::StringRef are non-owning pointers into a
// string in memory. This method is used to feed strings to LLVM
// & Clang APIs that expect llvm::StringRef.
-llvm::StringRef AsStringRef(tensorflow::StringPiece str);
+llvm::StringRef AsStringRef(absl::string_view str);
template <typename T>
llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
@@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module);
// - removing all '%'s.
//
string IrName(string a);
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b);
-string IrName(const HloInstruction* a, tensorflow::StringPiece b = "");
+string IrName(absl::string_view a, absl::string_view b);
+string IrName(const HloInstruction* a, absl::string_view b = "");
// Removes special characters from a function name.
//
@@ -164,21 +164,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
// This can be useful to avoid e.g. executing an alloca every time
// through a loop.
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment = 0);
// As EmitAllocaAtFunctionEntry, but allocates element_count entries
// instead of a single element.
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment = 0);
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment = 0);
// Creates a basic block with the same context and function as for the
// builder. Inserts at the end of the function if insert_before is
// null.
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b);
// Struct with data on a conditional branch in a diamond shape created
@@ -210,7 +212,7 @@ struct LlvmIfData {
// Currently the insertion point of the builder must be a well-formed
// block with a terminator. If you need to use this for a
// non-terminated block, just make the function able to do that too.
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else = true);
// Emits a compare operation between "lhs" and "rhs" with the given predicate,
@@ -285,8 +287,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module);
+ absl::string_view name, llvm::Module* module);
// Extracts the xla_backend_extra_options from `config` and passes those that
// don't start with xla_ to LLVM.
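A minimal standalone sketch of the prefix test used in InitializeLLVMCommandLineOptions above, tensorflow::str_util::StartsWith to absl::StartsWith; the option names are invented.

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/match.h"

int main() {
  std::vector<std::string> options = {"xla_backend_extra", "unroll-threshold"};
  for (const std::string& opt : options) {
    // Skip options the XLA backend itself consumes, as in the hunk above.
    if (!absl::StartsWith(opt, "xla_")) {
      std::cout << "forward to LLVM: " << opt << "\n";
    }
  }
}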
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 36f5fa1952..cf7445804c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
}
std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ absl::string_view loop_name, llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
@@ -122,7 +122,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
return {array_index};
}
-Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name,
+Status LoopEmitter::EmitLoop(absl::string_view loop_name,
llvm::Type* index_type) {
if (index_type == nullptr) {
index_type = b_->getInt64Ty();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index c4f5c82086..57d9d8bbc6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -69,10 +69,10 @@ class LoopEmitter {
}
virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type);
+ absl::string_view loop_name, llvm::Type* index_type);
// Emits a complete loop nest for every element in the given shape.
- Status EmitLoop(tensorflow::StringPiece loop_name = "",
+ Status EmitLoop(absl::string_view loop_name = "",
llvm::Type* index_type = nullptr);
protected:
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index c333311a7e..00dd3f1638 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -88,7 +88,7 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
const absl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions) {
const Shape& keys_shape = keys_array.GetShape();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index 39fffea931..527ed10374 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -32,7 +32,7 @@ namespace llvm_ir {
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
const absl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions);
} // namespace llvm_ir
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index b7cb782a7e..ea59adadea 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc
index c742d35a7b..e1f56727bd 100644
--- a/tensorflow/compiler/xla/service/logical_buffer.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer.cc
@@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {}
string LogicalBuffer::ToString() const {
string color_string;
if (has_color()) {
- color_string = tensorflow::strings::StrCat(" @", color().value());
+ color_string = absl::StrCat(" @", color().value());
}
- return tensorflow::strings::StrCat(instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "](#", id(), color_string, ")");
+ return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
+ "](#", id(), color_string, ")");
}
} // namespace xla
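A minimal standalone sketch of LogicalBuffer::ToString after the rewrite, combining absl::StrCat with absl::StrJoin (which replaces tensorflow::str_util::Join); the buffer name and id are invented.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<int64_t> index = {0, 2, 1};
  // Mirrors StrCat(name, "[", StrJoin(index_, ","), "](#", id(), ...).
  std::string s = absl::StrCat("fusion.3", "[", absl::StrJoin(index, ","),
                               "](#", 7, ")");
  std::cout << s << "\n";  // fusion.3[0,2,1](#7)
}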
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 6aa639a954..4c8cb7d379 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
- tensorflow::StringPiece name() const override {
- return "multi_output_fusion";
- }
+ absl::string_view name() const override { return "multi_output_fusion"; }
// Run multi-output fusion on the given module. Returns whether the module
// was changed.
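A minimal standalone sketch of the pass-name pattern repeated across these headers: name() now returns absl::string_view rather than tensorflow::StringPiece. Returning a view of a string literal is safe because literals have static storage duration. The class below is hypothetical.

#include <iostream>

#include "absl/strings/string_view.h"

class FakePass {
 public:
  // Mirrors overrides like: absl::string_view name() const override {...}
  absl::string_view name() const { return "multi_output_fusion"; }
};

int main() {
  FakePass pass;
  std::cout << pass.name() << "\n";  // multi_output_fusion
}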
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index f6e7578a89..70cd0a339a 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -52,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) {
return result;
}
-string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
+string NameUniquer::GetUniqueName(absl::string_view prefix) {
string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix));
// Strip away numeric suffix (if any). Only recognize separator if it is in
@@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
if (separator_index != string::npos && (separator_index > 0) &&
(separator_index < root.size() - 1)) {
string after_suffix = root.substr(separator_index + 1);
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
has_numeric_suffix = true;
// Remove numeric suffix from root.
root = root.substr(0, separator_index);
+ } else {
+ // absl::SimpleAtoi may modify numeric_suffix even if it returns false.
+ numeric_suffix = 0;
}
}
SequentialIdGenerator& id_generator = generated_names_[root];
numeric_suffix = id_generator.RegisterId(numeric_suffix);
if (numeric_suffix == 0) {
- return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0)
- : root;
+ return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root;
}
- tensorflow::strings::StrAppend(&root, separator_, numeric_suffix);
+ absl::StrAppend(&root, separator_, numeric_suffix);
return root;
}
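A minimal standalone sketch of the numeric-suffix parse in GetUniqueName, tensorflow::strings::safe_strto64 to absl::SimpleAtoi. As the new comment in the hunk notes, SimpleAtoi may clobber its output argument even when it returns false, hence the reset.

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/numbers.h"

int main() {
  int64_t numeric_suffix = 0;
  if (absl::SimpleAtoi("42", &numeric_suffix)) {
    std::cout << "suffix: " << numeric_suffix << "\n";  // suffix: 42
  }
  if (!absl::SimpleAtoi("not_a_number", &numeric_suffix)) {
    numeric_suffix = 0;  // restore a known value after a failed parse
  }
}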
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 4423d61069..6dd89c240f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@@ -38,7 +38,7 @@ class NameUniquer {
// Get a sanitized unique name in a string, with an optional prefix for
// convenience.
- string GetUniqueName(tensorflow::StringPiece prefix = "");
+ string GetUniqueName(absl::string_view prefix = "");
// Sanitizes and returns the name. Unallowed characters will be replaced with
// '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index ac6ea4c72f..ccc06ce613 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -622,7 +622,7 @@ template <typename Previous>
class HloInstructionPatternNameImpl {
public:
explicit HloInstructionPatternNameImpl(const Previous& previous,
- tensorflow::StringPiece name)
+ absl::string_view name)
: previous_(previous), name_(name) {}
bool Match(const ::xla::HloInstruction* inst) const {
@@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl {
private:
Previous previous_;
- tensorflow::StringPiece name_;
+ absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
@@ -784,7 +784,7 @@ class HloInstructionPattern {
// Modifies the pattern to match only if the instruction has the given name.
HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(tensorflow::StringPiece name) const {
+ WithName(absl::string_view name) const {
return HloInstructionPattern<HloInstructionType,
HloInstructionPatternNameImpl<Impl>>(
HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 39fe3c7835..150af0cd93 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -19,20 +19,19 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
-using tensorflow::str_util::Lowercase;
-
// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;
@@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter";
namespace {
string CanonicalPlatformName(const string& name) {
- string platform_str = Lowercase(name);
+ string platform_str = absl::AsciiStrToLower(name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
@@ -94,7 +93,7 @@ PlatformUtil::GetSupportedPlatforms() {
}
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
@@ -110,15 +109,15 @@ PlatformUtil::GetSupportedPlatforms() {
return platforms[0];
} else if (platforms.size() == 2) {
for (int i = 0; i < 2; i++) {
- if (Lowercase(platforms[i]->Name()) == kInterpreter &&
- Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
+ if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
+ absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
return platforms[1 - i];
}
}
}
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
@@ -132,7 +131,7 @@ PlatformUtil::GetSupportedPlatforms() {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) == platform_str) {
+ if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
return platform;
}
}
@@ -146,7 +145,7 @@ PlatformUtil::GetSupportedPlatforms() {
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
std::vector<se::Platform*> matched;
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) != platform_name) {
+ if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
matched.push_back(platform);
}
}
@@ -157,7 +156,7 @@ PlatformUtil::GetSupportedPlatforms() {
if (matched.size() == 1) {
return matched[0];
}
- string matched_string = tensorflow::str_util::Join(
+ string matched_string = absl::StrJoin(
matched, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
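A minimal standalone sketch of the two remaining patterns in platform_util.cc: Lowercase to absl::AsciiStrToLower, and Join with a custom appender to absl::StrJoin with a lambda formatter. The platform names below stand in for se::Platform::Name().

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<std::string> platforms = {"Host", "CUDA"};
  // Case-insensitive match, as in CanonicalPlatformName.
  std::cout << absl::AsciiStrToLower(platforms[0]) << "\n";  // host
  // The formatter lambda appends each element, as in the error-message hunks.
  std::string joined = absl::StrJoin(
      platforms, ", ",
      [](std::string* out, const std::string& p) { out->append(p); });
  std::cout << joined << "\n";  // Host, CUDA
}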
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index afde3cf95c..256b231e3a 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface {
~ReducePrecisionInsertion() override{};
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "reduce-precision-insertion";
}
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1f59e3b314..1e86a0823a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -26,7 +26,7 @@ namespace xla {
// them inputward also.
class ReshapeMover : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "reshape-mover"; }
+ absl::string_view name() const override { return "reshape-mover"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 7534a3f7e3..a395dd5333 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -28,13 +28,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using ReshapeMoverTest = HloVerifiedTestBase;
+
+namespace op = xla::testing::opcode_matchers;
+
+class ReshapeMoverTest : public HloVerifiedTestBase {
+ public:
+ ReshapeMoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 8f735e877d..14f062c89c 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -22,7 +22,7 @@ namespace xla {
class ScatterExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "scatter_expander"; }
+ absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 18d1b7732b..d39a5191b8 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
@@ -46,7 +47,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,8 +55,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
+using absl::StrCat;
using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrCat;
namespace xla {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec6aa6df55..6a22f8bef4 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -22,6 +22,9 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -29,28 +32,24 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
-using tensorflow::str_util::Join;
-using tensorflow::strings::Printf;
-
namespace xla {
-
namespace {
+using absl::StrJoin;
+using tensorflow::strings::Printf;
+
// Returns true if no element is present in slice more than once.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+Status ExpectArray(const Shape& shape, absl::string_view op_type) {
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument("Expected array argument for %s, but got %s.",
std::string(op_type).c_str(),
@@ -234,10 +233,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
switch (opcode) {
case HloOpcode::kFloor:
case HloOpcode::kCeil:
+ case HloOpcode::kRoundNearestAfz:
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating for floor/ceil "
- "operation; got %s.",
+ "Expected element type in shape to be floating for %s operation; "
+ "got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
return shape;
@@ -251,8 +252,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (!ShapeUtil::ElementIsFloating(shape) &&
!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating or complex for "
- "sin/cos/exp/log/tanh operation; got %s.",
+ "Expected element type in shape to be floating or complex for %s "
+ "operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
return shape;
@@ -265,19 +267,51 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
} else {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
- "real/imag operation; got %s.",
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
shape, primitive_util::ComplexComponentType(shape.element_type()));
+ } else if (ShapeUtil::ElementIsSigned(shape)) {
+ return shape;
+ } else {
+ return InvalidArgument(
+ "Expected element type in shape to be floating or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return shape;
case HloOpcode::kClz:
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Expected an integral element type in argument to Clz "
+ "operation; got %s.",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return shape;
case HloOpcode::kNegate:
- case HloOpcode::kRoundNearestAfz:
+ if (!ShapeUtil::ElementIsIntegral(shape) &&
+ !ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be integral, floating or "
+ "complex for %s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return shape;
case HloOpcode::kSign:
+ if (!ShapeUtil::ElementIsSigned(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be signed or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
return shape;
case HloOpcode::kNot:
@@ -879,16 +913,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str(),
- Join(broadcast_dimensions, ", ").c_str());
+ StrJoin(broadcast_dimensions, ", ").c_str());
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(
- ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(
- ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -1059,7 +1091,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Map operation requires all operands to have the same shape; got: "
"%s.",
- Join(pieces, ", ").c_str());
+ StrJoin(pieces, ", ").c_str());
}
// Check that dimensions.size == arg_shape.dimensions_size() (we currently
@@ -1076,7 +1108,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (dimensions[i] != i) {
return InvalidArgument(
"Map requires monotonically increasing dimension numbers; got: %s.",
- Join(dimensions, ", ").c_str());
+ StrJoin(dimensions, ", ").c_str());
}
}
@@ -1977,14 +2009,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
"{%s}; strides: {%s}.",
message.c_str(), ShapeUtil::HumanString(arg).c_str(),
- Join(starts, ",").c_str(), Join(limits, ",").c_str(),
- Join(strides, ",").c_str());
+ StrJoin(starts, ",").c_str(), StrJoin(limits, ",").c_str(),
+ StrJoin(strides, ",").c_str());
};
TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
- ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
- Join(limits, ", ").c_str());
+ ShapeUtil::HumanString(arg).c_str(), StrJoin(starts, ", ").c_str(),
+ StrJoin(limits, ", ").c_str());
if (starts.size() != limits.size()) {
return error(Printf("slice start and limit sizes differ: %zu vs %zu",
@@ -2047,7 +2079,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
ShapeUtil::HumanString(operand_shape).c_str(),
ShapeUtil::HumanString(start_indices_shape).c_str(),
- Join(slice_sizes, ", ").c_str());
+ StrJoin(slice_sizes, ", ").c_str());
if (ShapeUtil::Rank(start_indices_shape) != 1) {
return InvalidArgument(
@@ -2344,7 +2376,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Reshape dimensions [%s] are not a permutation of the operand "
"dimensions (operand shape is %s).",
- Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str());
+ StrJoin(dimensions, ",").c_str(),
+ ShapeUtil::HumanString(operand).c_str());
}
return inferred_shape;
@@ -2464,8 +2497,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
string argument_shapes =
- Join(arg_shapes, ", ", [](string* out, const Shape* shape) {
- tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
+ StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
+ absl::StrAppend(out, ShapeUtil::HumanString(*shape));
});
return InvalidArgument(
"Call applied function arity must match number of arguments; got: "
@@ -2498,14 +2531,14 @@ static Status ValidateGatherDimensionNumbers(
if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.offset_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", ").c_str());
}
if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.offset_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", ").c_str());
}
const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
@@ -2554,7 +2587,7 @@ static Status ValidateGatherDimensionNumbers(
return InvalidArgument(
"Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.start_index_map(), ", ").c_str());
+ StrJoin(dim_numbers.start_index_map(), ", ").c_str());
}
for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
@@ -2569,7 +2602,7 @@ static Status ValidateGatherDimensionNumbers(
if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
"collapsed_slice_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
@@ -2577,7 +2610,7 @@ static Status ValidateGatherDimensionNumbers(
return InvalidArgument(
"Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.collapsed_slice_dims(), ", ").c_str());
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str());
}
return Status::OK();
@@ -2639,8 +2672,9 @@ static Status ValidateGatherDimensionNumbers(
"All components of the offset index in a gather op must either be a "
"offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
"output_slice_sizes=%s, collapsed_slice_dims=%s.",
- slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(),
- Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
+ slice_sizes.size(),
+ StrJoin(gather_dim_numbers.offset_dims(), ",").c_str(),
+ StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",").c_str());
}
for (int i = 0; i < slice_sizes.size(); i++) {
@@ -2703,13 +2737,13 @@ Status ValidateScatterDimensionNumbers(
if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
return InvalidArgument(
"update_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", ").c_str());
}
if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
dim_numbers.update_window_dims().end()) {
return InvalidArgument(
"update_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", ").c_str());
}
const int64 updates_rank = ShapeUtil::Rank(updates_shape);
for (int64 window_dim : dim_numbers.update_window_dims()) {
@@ -2725,13 +2759,13 @@ Status ValidateScatterDimensionNumbers(
if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
return InvalidArgument(
"inserted_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str());
}
if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
dim_numbers.inserted_window_dims().end()) {
return InvalidArgument(
"inserted_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str());
}
for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
@@ -2773,7 +2807,7 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
"got: %s.",
- Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
+ StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
}
return Status::OK();
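Throughout shape_inference.cc the pattern is Join(range, sep) becoming absl::StrJoin(range, sep) with identical argument order; StrJoin only needs begin()/end(), so protobuf repeated fields and gtl::ArraySlice work unchanged. A hypothetical standalone sketch:

    #include <cstdio>
    #include <string>
    #include <vector>
    #include "absl/strings/str_join.h"

    int main() {
      std::vector<long long> dims = {0, 2, 4};
      // Same call shape as the old tensorflow::str_util::Join.
      std::string joined = absl::StrJoin(dims, ", ");  // "0, 2, 4"
      std::printf("dims={%s}\n", joined.c_str());
      return 0;
    }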
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 70714ffff0..5c12dc37b7 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -76,7 +77,7 @@ void ShapedBuffer::clear() {
}
string ShapedBuffer::ToString() const {
- string s = tensorflow::strings::StrCat(
+ string s = absl::StrCat(
"ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
"), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
", on-device shape=" +
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index e0f995fd0d..0c577ec67a 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -28,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
namespace xla {
/* static */ tensorflow::mutex
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 71e8446452..3e5aa2db60 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface {
explicit TransposeFolding(
TransposableGemmOperandsFn transposable_gemm_operands,
TransposableConvOperandsFn transposable_conv_operands);
- tensorflow::StringPiece name() const override { return "transpose-folding"; }
+ absl::string_view name() const override { return "transpose-folding"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0c2f2112af..cb07b8d4d3 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -27,17 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
string BufferAlias::ToString() const {
- return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "])");
+ return absl::StrCat("BufferAlias(", instruction_->name(), "[",
+ absl::StrJoin(index_, ","), "])");
}
std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
@@ -496,8 +495,7 @@ StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
return FailedPrecondition(
"instruction %s does not define buffer at index {%s}",
- instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str());
+ instruction->name().c_str(), absl::StrJoin(index, ",").c_str());
}
return buffers[0];
}
@@ -563,8 +561,7 @@ string TuplePointsToAnalysis::ToString() const {
for (const auto* computation : module_->MakeNonfusionComputations()) {
const char* entry =
computation == module_->entry_computation() ? "entry " : "";
- tensorflow::strings::StrAppend(&output, entry, "computation ",
- computation->name(), ":\n");
+ absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
InstructionToString(instruction, &output);
@@ -576,12 +573,11 @@ string TuplePointsToAnalysis::ToString() const {
}
}
- tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n");
+ absl::StrAppend(&output, "LogicalBuffers:\n");
for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
- tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n");
+ absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
- tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(),
- "\n");
+ absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
}
}
return output;
@@ -590,20 +586,18 @@ string TuplePointsToAnalysis::ToString() const {
void TuplePointsToAnalysis::InstructionToString(
const HloInstruction* instruction, string* output) const {
const string prefix = instruction->IsFused() ? " " : "";
- tensorflow::strings::StrAppend(output, prefix, " instruction ",
- instruction->ToShortString(), ":\n");
+ absl::StrAppend(output, prefix, " instruction ",
+ instruction->ToShortString(), ":\n");
const PointsToSet& points_to_set = GetPointsToSet(instruction);
points_to_set.ForEachElement([&prefix, &output](
const ShapeIndex& index,
const PointsToSet::BufferList& points_to) {
- tensorflow::strings::StrAppend(
- output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ",
- tensorflow::str_util::Join(
- points_to, ", ",
- [](string* out, const LogicalBuffer* source) {
- out->append(source->ToString());
- }),
- "\n");
+ absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
+ absl::StrJoin(points_to, ", ",
+ [](string* out, const LogicalBuffer* source) {
+ out->append(source->ToString());
+ }),
+ "\n");
});
}
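The three-argument StrJoin used above takes a formatter invoked as f(string* out, const T&), so the old str_util::Join lambdas carry over with only StrAppend's namespace changing. A minimal sketch with a hypothetical stand-in for LogicalBuffer:

    #include <string>
    #include <vector>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    struct Buffer {  // hypothetical stand-in for LogicalBuffer
      int id;
      std::string ToString() const { return absl::StrCat("buf", id); }
    };

    int main() {
      std::vector<Buffer> buffers = {{1}, {2}};
      // The formatter appends each element's rendering to *out.
      std::string s = absl::StrJoin(
          buffers, ", ", [](std::string* out, const Buffer& b) {
            absl::StrAppend(out, b.ToString());
          });
      return s == "buf1, buf2" ? 0 : 1;
    }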
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 7509501883..8c91d6e69d 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface {
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
- tensorflow::StringPiece name() const override { return "tuple-simplifier"; }
+ absl::string_view name() const override { return "tuple-simplifier"; }
// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 21fb8568a8..2dba7d7f75 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface {
public:
~WhileLoopConstantSinking() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8e6cc87875..2cdf20ce80 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface {
: hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 32e69c335b..e14014b961 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
+ WhileLoopInvariantCodeMotionTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation that has one parameter of the given shape and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index a24e2b0116..6a7bfe3f12 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -236,12 +236,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
<< "Instruction " << user->ToString(print_no_metadata)
<< " should be unused (except by root of while body), but has "
"users: {"
- << tensorflow::str_util::Join(
- user->users(), ", ",
- [&](string* out, const HloInstruction* instr) {
- tensorflow::strings::StrAppend(
- out, instr->ToString(print_no_metadata));
- })
+ << absl::StrJoin(user->users(), ", ",
+ [&](string* out, const HloInstruction* instr) {
+ absl::StrAppend(
+ out, instr->ToString(print_no_metadata));
+ })
<< "}";
replacements.emplace(user, nullptr);
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 3d3e1d60f2..78024f14dc 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -33,9 +33,7 @@ namespace xla {
class WhileLoopSimplifier : public HloPassInterface {
public:
~WhileLoopSimplifier() override {}
- tensorflow::StringPiece name() const override {
- return "simplify-while-loops";
- }
+ absl::string_view name() const override { return "simplify-while-loops"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 2e1571943e..cfe4104f6d 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -15,11 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -27,6 +28,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
+ public:
+ WhileLoopSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Makes an HloModule that contains a loop with `num_iters` iterations.
void MakeModuleWithSimpleLoop(int num_iters);
@@ -64,10 +70,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
@@ -103,10 +107,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
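Both helpers above replace str_util::StringReplace(s, from, to, /*replace_all=*/true) with absl::StrReplaceAll, which takes {from, to} pairs and performs every substitution in one pass. A standalone sketch with a hypothetical template string:

    #include <string>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_replace.h"

    int main() {
      const std::string tmpl = "trip_count={{LOOP_BOUND}}";
      int num_iters = 3;
      // All occurrences of the placeholder are rewritten in one call.
      std::string hlo = absl::StrReplaceAll(
          tmpl, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
      return hlo == "trip_count=45" ? 0 : 1;
    }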
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 52d9c3e5ae..e8f76ff745 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -15,15 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
static StatusOr<HloComputation*> WidenWhileCondition(
HloComputation* narrow_condition, const Shape& wide_shape) {
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index 8763e588c4..a7f0e207eb 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -24,7 +24,7 @@ namespace xla {
class ZeroSizedHloElimination : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "zero_sized_hlo_elimination";
}
};
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7244be80d9..31ddd57eef 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -22,6 +22,13 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -31,25 +38,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
string ShapeIndexView::ToString() const {
- return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
+ return StrCat("{", absl::StrJoin(indices_, ","), "}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
@@ -449,14 +453,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
namespace {
// Class to memoize the computation of
-// tensorflow::str_util::Lowercase(PrimitiveType_Name(p))
+// absl::AsciiStrToLower(PrimitiveType_Name(p))
// for all PrimitiveType values "p"
class PrimitiveTypeNameGenerator {
public:
PrimitiveTypeNameGenerator() {
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
if (PrimitiveType_IsValid(i)) {
- lowercase_name_[i] = tensorflow::str_util::Lowercase(
+ lowercase_name_[i] = absl::AsciiStrToLower(
PrimitiveType_Name(static_cast<PrimitiveType>(i)));
}
}
@@ -507,7 +511,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
return text;
}
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
- tensorflow::str_util::Join(shape.dimensions(), ","), "]");
+ absl::StrJoin(shape.dimensions(), ","), "]");
}
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
@@ -543,30 +547,30 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
: "(unknown)",
": ", HumanString(shape)));
}
- return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
+ return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
HumanString(program_shape.result()));
}
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
-StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
- tensorflow::str_util::RemoveLeadingWhitespace(s);
+StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
+ *s = StripLeadingAsciiWhitespace(*s);
- if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple.
+ if (absl::ConsumePrefix(s, "(")) { // Tuple.
std::vector<Shape> shapes;
bool must_end = false;
while (true) {
- if (tensorflow::str_util::ConsumePrefix(s, ")")) {
+ if (absl::ConsumePrefix(s, ")")) {
break;
} else if (must_end) {
return InvalidArgument("Expected end of tuple; got: \"%s\"",
- std::string(*s).c_str());
+ string(*s).c_str());
}
shapes.emplace_back();
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
- tensorflow::str_util::RemoveLeadingWhitespace(s);
- must_end = !tensorflow::str_util::ConsumePrefix(s, ",");
+ *s = StripLeadingAsciiWhitespace(*s);
+ must_end = !absl::ConsumePrefix(s, ",");
}
return ShapeUtil::MakeTupleShape(shapes);
}
@@ -575,9 +579,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
string dimensions_string;
string format_string;
string layout_string;
- // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
+ // absl::string_view is not compatible with internal RE2 StringPiece, so
// we convert into the RE2-consumable type and then consume the corresponding
- // amount from our StringPiece type.
+ // amount from our string_view type.
static LazyRE2 shape_pattern = {
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"};
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
@@ -585,12 +589,12 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
&dimensions_string, &format_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
- auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
+ auto string_to_int64 = [&s](absl::string_view input) -> StatusOr<int64> {
int64 element;
- if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) {
+ if (!absl::SimpleAtoi(input, &element)) {
return InvalidArgument(
"Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
- input.c_str(), std::string(*s).c_str());
+ string(input).c_str(), string(*s).c_str());
}
return element;
};
@@ -598,7 +602,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
auto comma_list_to_int64s =
[string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
- for (const string& piece : tensorflow::str_util::Split(input, ',')) {
+ for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) {
TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
results.push_back(element);
}
@@ -645,16 +649,15 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
return InvalidArgument("Invalid shape string to parse: \"%s\"",
- std::string(*s).c_str());
+ string(*s).c_str());
}
} // namespace
-/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
- tensorflow::StringPiece s) {
+/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(absl::string_view s) {
TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
if (!s.empty()) {
return InvalidArgument("Invalid shape string to parse: \"%s\"",
- std::string(s).c_str());
+ string(s).c_str());
}
return shape;
}
@@ -1172,8 +1175,7 @@ Status ForEachMutableSubshapeHelper(
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
<< "shape=" << HumanStringWithLayout(shape)
<< ", new_shape=" << HumanStringWithLayout(new_shape)
- << ", permutation={" << tensorflow::str_util::Join(permutation, ",")
- << "}";
+ << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
}
return new_shape;
}
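The shape-string parser above swaps four helpers at once: RemoveLeadingWhitespace becomes StripLeadingAsciiWhitespace (which returns the stripped view instead of mutating in place, hence the assignment back through *s), ConsumePrefix moves to absl, safe_strto64 becomes SimpleAtoi, and Split becomes StrSplit with SkipEmpty. A condensed, hypothetical sketch of the same consume-from-the-front style:

    #include <cstdint>
    #include <vector>
    #include "absl/strings/numbers.h"
    #include "absl/strings/str_split.h"
    #include "absl/strings/string_view.h"
    #include "absl/strings/strip.h"

    int main() {
      absl::string_view s = "  (1,2,3";
      s = absl::StripLeadingAsciiWhitespace(s);     // returns a view; assign back
      if (!absl::ConsumePrefix(&s, "(")) return 1;  // mutates s on success
      std::vector<int64_t> dims;
      for (absl::string_view piece : absl::StrSplit(s, ',', absl::SkipEmpty())) {
        int64_t v;
        if (!absl::SimpleAtoi(piece, &v)) return 1;  // non-numeric piece
        dims.push_back(v);
      }
      return dims.size() == 3 ? 0 : 1;
    }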
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index cb72fbbb0e..84f36e48a0 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -228,7 +228,7 @@ class ShapeUtil {
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
- static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
+ static StatusOr<Shape> ParseShapeString(absl::string_view s);
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index e5dd62ae9a..7549ba9c78 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include <numeric>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -23,8 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
@@ -849,13 +849,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
std::iota(layout.begin(), layout.end(), 0);
do {
Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout);
- SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s)));
+ SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s)));
std::vector<int64> permutation(3);
std::iota(permutation.begin(), permutation.end(), 0);
do {
- SCOPED_TRACE(tensorflow::strings::StrCat(
- "permutation=", tensorflow::str_util::Join(permutation, ",")));
+ SCOPED_TRACE(
+ absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
// TransposeIsBitcast takes the inverse of the permutation that
// PermuteDimensions takes.
diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc
index a6b1f9004f..b88fe367d7 100644
--- a/tensorflow/compiler/xla/status_macros.cc
+++ b/tensorflow/compiler/xla/status_macros.cc
@@ -17,9 +17,8 @@ limitations under the License.
#include <algorithm>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line,
if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) {
string stack_trace;
if (should_log_stack_trace) {
- stack_trace =
- tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace());
+ stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace());
}
switch (log_severity) {
case tensorflow::INFO:
@@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() {
is_done_ = true;
const string& stream_str = stream_.str();
- const string str =
- prior_message_handling_ == kAppendToPriorMessage
- ? tensorflow::strings::StrCat(prior_message_, stream_str)
- : tensorflow::strings::StrCat(stream_str, prior_message_);
+ const string str = prior_message_handling_ == kAppendToPriorMessage
+ ? absl::StrCat(prior_message_, stream_str)
+ : absl::StrCat(stream_str, prior_message_);
if (TF_PREDICT_FALSE(str.empty())) {
- return MakeError(file_, line_, code_,
- tensorflow::strings::StrCat(
- str, "Error without message at ", file_, ":", line_),
- true /* should_log */,
- tensorflow::ERROR /* log_severity */,
- should_log_stack_trace_);
+ return MakeError(
+ file_, line_, code_,
+ absl::StrCat(str, "Error without message at ", file_, ":", line_),
+ true /* should_log */, tensorflow::ERROR /* log_severity */,
+ should_log_stack_trace_);
} else {
return MakeError(file_, line_, code_, str, should_log_, log_severity_,
should_log_stack_trace_);
diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h
index 8918350135..3ede5e6e38 100644
--- a/tensorflow/compiler/xla/test_helpers.h
+++ b/tensorflow/compiler/xla/test_helpers.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6baf95d631..6b29d833da 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -43,6 +43,7 @@ cc_library(
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
alwayslink = True,
)
@@ -205,6 +206,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -391,6 +393,7 @@ xla_test(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -557,6 +560,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -671,6 +675,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -689,7 +694,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -697,6 +701,7 @@ xla_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -746,7 +751,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -754,6 +758,7 @@ xla_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -829,7 +834,10 @@ xla_test(
timeout = "long",
srcs = ["convolution_test.cc"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
)
xla_test(
@@ -839,7 +847,10 @@ xla_test(
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
backends = ["gpu"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"],
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
)
xla_test(
@@ -924,6 +935,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1001,6 +1013,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
],
)
@@ -1128,6 +1141,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1157,6 +1171,7 @@ xla_test_library(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1231,12 +1246,12 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- ":client_library_test_base",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1247,12 +1262,12 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- ":client_library_test_base",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1425,6 +1440,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1494,6 +1510,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1648,6 +1665,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1660,7 +1678,6 @@ xla_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
@@ -1671,6 +1688,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1851,13 +1869,9 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
- "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1866,6 +1880,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2026,6 +2041,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 84c5b6e549..577fd1ab3b 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -296,6 +296,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
+XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
+ XlaBuilder b(TestName());
+
+ std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
+ std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+
+ std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
+ std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+
+ Lt(lhs_param, rhs_param);
+
+ ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
+}
+
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
const int count = GetParam();
XlaBuilder builder(TestName());
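The new CmpTwoConstantU64s test feeds Lt a value with the top bit set against one without it; an implementation that lowered the U64 comparison through signed arithmetic would invert the answer. A sketch of the failure mode being guarded against (plain C++, outside XLA):

    #include <cstdint>

    int main() {
      uint64_t lhs = 0x8000000000000000ULL;  // top bit set
      uint64_t rhs = 0x7FFFFFFFFFFFFFFFULL;
      bool unsigned_lt = lhs < rhs;  // false: lhs is the larger unsigned value
      // Reinterpreted as signed, the result flips -- the bug the test catches.
      bool signed_lt =
          static_cast<int64_t>(lhs) < static_cast<int64_t>(rhs);  // true
      return (!unsigned_lt && signed_lt) ? 0 : 1;
    }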
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 24b17b7100..ac90a3adb6 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -382,7 +382,7 @@ struct BatchNormTestParam {
friend ::std::ostream& operator<<(::std::ostream& os,
const BatchNormTestParam& p) {
- os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
+ os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
os << "feature_index=" << p.feature_index << ", ";
os << "random_value_mean=" << p.random_value_mean << ", ";
os << "random_value_var=" << p.random_value_var;
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 2cab3264a7..9cd974fd9b 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -196,8 +196,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual, tensorflow::strings::StrCat(
- "Test with output layout: ",
+ verify_output(*actual,
+ absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
return Status::OK();
@@ -258,7 +258,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
output_with_layout));
string error_message = "Test with input layouts: ";
for (const auto& str : layout_strings) {
- tensorflow::strings::StrAppend(&error_message, str, " ");
+ absl::StrAppend(&error_message, str, " ");
}
verify_output(*actual, error_message);
return Status::OK();
@@ -391,7 +391,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
void ClientLibraryTestBase::ComputeAndCompareR1U8(
- XlaBuilder* builder, tensorflow::StringPiece expected,
+ XlaBuilder* builder, absl::string_view expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 24d0325929..ac96d3e325 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -36,7 +37,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
@@ -202,7 +202,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// Compare the result of the computation to a string. In XLA strings are
// represented using rank-1 U8 shapes.
void ComputeAndCompareR1U8(
- XlaBuilder* builder, tensorflow::StringPiece expected,
+ XlaBuilder* builder, absl::string_view expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
// Convenience method for running a built computation, transferring the
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 5a06d061f0..8226b6de3f 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
- "depends on a parameter"))
+ EXPECT_TRUE(
+ absl::StrContains(value.status().ToString(), "depends on a parameter"))
<< value.status();
}
}
@@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
- "depends on a parameter"))
+ EXPECT_TRUE(
+ absl::StrContains(value.status().ToString(), "depends on a parameter"))
<< value.status();
}
}
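str_util::StrContains maps to absl::StrContains with the same (haystack, needle) order; both parameters are absl::string_view, so the std::string from Status::ToString() converts implicitly. A minimal hypothetical sketch:

    #include <string>
    #include "absl/strings/match.h"

    int main() {
      const std::string status_text = "INVALID_ARGUMENT: depends on a parameter";
      return absl::StrContains(status_text, "depends on a parameter") ? 0 : 1;
    }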
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 40658c3b77..d2c6478b02 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -35,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0e9e92ed99..5873516442 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -261,16 +262,14 @@ string PrintDotTestParam(
const ::testing::TestParamInfo<DotTestParam>& test_param) {
const DotTestParam& param = test_param.param;
if (param.has_addend) {
- return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
- "_MajorToMinor",
- param.dot_lhs_row_major ? "T" : "F",
- param.dot_rhs_row_major ? "T" : "F",
- param.addend_row_major ? "T" : "F");
+ return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
+ param.dot_lhs_row_major ? "T" : "F",
+ param.dot_rhs_row_major ? "T" : "F",
+ param.addend_row_major ? "T" : "F");
} else {
- return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
- "_MajorToMinor",
- param.dot_lhs_row_major ? "T" : "F",
- param.dot_rhs_row_major ? "T" : "F");
+ return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
+ param.dot_lhs_row_major ? "T" : "F",
+ param.dot_rhs_row_major ? "T" : "F");
}
}
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
index 39cc6c5927..4a835a8e21 100644
--- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -16,13 +16,13 @@ limitations under the License.
#include <limits>
#include <string>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -39,8 +39,7 @@ class FloorCeilTest : public ClientLibraryTestBase {
// Runs a computation and comparison on expected vs f(input)
void TestR1F32(tensorflow::gtl::ArraySlice<float> input,
tensorflow::gtl::ArraySlice<float> expected, Function f) {
- LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ")
- << "}";
+ LOG(INFO) << "input: {" << absl::StrJoin(input, ", ") << "}";
XlaBuilder builder(TestName());
auto c = ConstantR1<float>(&builder, input);
if (f == kCeil) {
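
absl::StrJoin replaces tensorflow::str_util::Join one-for-one for this logging idiom. A runnable sketch (sample values are made up):

#include <iostream>
#include <vector>
#include "absl/strings/str_join.h"

int main() {
  std::vector<float> input = {1.5f, -2.0f, 3.25f};
  // Each element is stringified with the default formatter and joined
  // with the separator, matching the removed str_util::Join call.
  std::cout << "input: {" << absl::StrJoin(input, ", ") << "}\n";
}
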
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 5635c3fe86..93ea144438 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -43,7 +43,7 @@ namespace xla {
namespace {
using absl::optional;
-using tensorflow::StringPiece;
+using absl::string_view;
using tensorflow::gtl::ArraySlice;
constexpr char kInterpreter[] = "interpreter";
@@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
-HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier)
+HloTestBase::HloTestBase(bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
+ verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
+ bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ =
- absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+ hlo_verifier_ = absl::make_unique<HloVerifier>(
+ /*layout_sensitive=*/verifier_layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
@@ -239,7 +243,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompare(
- const StringPiece hlo_string, const absl::optional<ErrorSpec>& error,
+ string_view hlo_string, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
@@ -252,7 +256,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
reference_preprocessor);
}
-::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) {
+::testing::AssertionResult HloTestBase::Run(string_view hlo_string) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
@@ -289,7 +293,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
- const StringPiece hlo_string, const absl::optional<ErrorSpec>& error,
+ string_view hlo_string, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
@@ -316,7 +320,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
HloComputation* HloTestBase::FindComputation(HloModule* module,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
auto computations = module->computations();
auto it = absl::c_find_if(
computations, [&](HloComputation* c) { return c->name() == name; });
@@ -327,7 +331,7 @@ HloComputation* HloTestBase::FindComputation(HloModule* module,
}
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
for (const HloComputation* c : module->computations()) {
auto instructions = c->instructions();
auto it = absl::c_find_if(
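
absl::string_view is API-compatible with tensorflow::StringPiece at these call sites: both are non-owning views, so callers passing std::string, literals, or other views need no changes. A small standalone sketch of the parameter migration (the function is hypothetical; only absl is assumed):

#include <iostream>
#include <string>
#include "absl/strings/string_view.h"

// Takes absl::string_view where the old signature took
// tensorflow::StringPiece; std::string and literals convert implicitly.
bool MentionsComputation(absl::string_view text, absl::string_view name) {
  return text.find(name) != absl::string_view::npos;
}

int main() {
  std::string hlo = "ENTRY main { ... }";
  std::cout << MentionsComputation(hlo, "main") << "\n";  // 1
}
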
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index d88abf561a..06bcc39741 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test {
// automatically finds another supported backend as the test backend. If the
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
- HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true);
+ HloTestBase(bool verifier_layout_sensitive = false,
+ bool allow_mixed_precision_in_hlo_verifier = true);
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
+ bool verifier_layout_sensitive = false,
bool allow_mixed_precision_in_hlo_verifier = true);
~HloTestBase() override {}
@@ -169,18 +171,18 @@ class HloTestBase : public ::testing::Test {
// input. Module can be passed in directly, or parsed from an hlo_string,
// or loaded from a file.
::testing::AssertionResult RunAndCompare(
- const tensorflow::StringPiece hlo_string,
+ const absl::string_view hlo_string,
const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
- ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string)
+ ::testing::AssertionResult Run(const absl::string_view hlo_string)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareFromFile(
const string& filename, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareNoHloPasses(
- const tensorflow::StringPiece hlo_string,
+ const absl::string_view hlo_string,
const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -228,10 +230,8 @@ class HloTestBase : public ::testing::Test {
//
// This is useful for tests which create HLOs from a string and then want to
// inspect a particular computation or instruction.
- HloComputation* FindComputation(HloModule* module,
- tensorflow::StringPiece name);
- HloInstruction* FindInstruction(HloModule* module,
- tensorflow::StringPiece name);
+ HloComputation* FindComputation(HloModule* module, absl::string_view name);
+ HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
// Return an HLO verifier constructed for the test backend.
HloVerifier& verifier() const { return *hlo_verifier_; }
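
With layout sensitivity now threaded through the fixture, a derived test opts in through the new constructor parameters. A sketch of a hypothetical fixture (builds only inside the TensorFlow tree):

#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {

// Hypothetical fixture: requests layout-sensitive verification while
// keeping the default mixed-precision allowance.
class LayoutSensitiveHloTest : public HloTestBase {
 protected:
  LayoutSensitiveHloTest()
      : HloTestBase(/*verifier_layout_sensitive=*/true,
                    /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};

}  // namespace xla
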
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index a509ee3207..8f86c528d0 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -25,8 +25,11 @@ limitations under the License.
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase()
- : shape_verifier_(absl::make_unique<ShapeVerifier>()) {}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
@@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() {
}
void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- HloVerifier verifier(/*allow_mixed_precision=*/true);
- xla::StatusOr<bool> mutated = verifier.Run(module);
+ xla::StatusOr<bool> mutated = verifier().Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -73,7 +75,7 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
return modules_.back().get();
}
-void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 5b28c01c36..cc6967feed 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -29,7 +29,8 @@ namespace xla {
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
protected:
- HloVerifiedTestBase();
+ explicit HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision);
~HloVerifiedTestBase() override;
// Constructs a default shape verifier.
@@ -44,32 +45,28 @@ class HloVerifiedTestBase : public HloTestBase {
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
HloModule& module();
- void ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+ void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
- // Sets the shape-size function used during hlo verification. If this isn't
- // called, a default ShapeVerifier is used instead.
- void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) {
- shape_verifier_ = std::move(shape_verifier);
- }
-
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
HloModule* CreateNewModule(const string& name = TestName());
+ private:
+ void VerifyModule(HloModule* module);
+
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
- private:
+ //
// Lazily populated. Access via module().
std::unique_ptr<HloModule> module_;
// Populated by calls to CreateNewModule.
std::vector<std::unique_ptr<HloModule>> modules_;
- std::unique_ptr<ShapeVerifier> shape_verifier_;
+
bool tear_down_called_ = false;
- static void VerifyModule(HloModule* module);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index f297b2b847..4151bfae03 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -80,7 +80,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
std::vector<string> results;
TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));
- LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
+ LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
EXPECT_EQ(3, results.size());
for (const string& result : results) {
LiteralProto literal_proto;
@@ -105,8 +105,10 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result =
LiteralTestUtil::Equal(*expected, *actual);
- EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
- EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Actual literal:\n{4, 5, 6}"));
}
TEST(LiteralTestUtilTest, NearComparatorR1) {
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index b6035a21a6..7956a034f8 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -158,7 +159,7 @@ class TestLinspaceMaxParametric
string PrintTestLinspaceMaxParam(
const ::testing::TestParamInfo<TestLinspaceMaxParam>& test_param) {
const TestLinspaceMaxParam& param = test_param.param;
- return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c");
+ return absl::StrCat(param.rows, "r", param.cols, "c");
}
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index cadf1c5523..16b77e965d 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -52,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase {
protected:
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
+ // Layout assignment assumes that there are no fusions in the input graph.
+ // Since the purpose of this test is to send pre-fused graphs to XLA, we have
+ // to do layout assignment ourselves.
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.add_xla_disable_hlo_passes("layout-assignment");
+ return opts;
+ }
+
void RunTest2D(bool manual_fusion, int64 size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {});
- const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
+ const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ const Shape elem_shape2 =
+ ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
@@ -100,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
arg1.PopulateWithValue<float>(2.5f);
- Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual =
ExecuteAndTransfer(std::move(hlo_module),
@@ -115,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size});
- const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size});
+ const Shape elem_shape_F32 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
+ const Shape elem_shape_U8 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
auto param1 = builder.AddInstruction(
@@ -136,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase {
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {size, 1}), add));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
+ dot_dnums));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -161,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal input0(ShapeUtil::MakeShape(F32, {size}));
+ Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
input0.PopulateWithValue(2.5f);
- Literal input1(ShapeUtil::MakeShape(F64, {size}));
+ Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
Literal expect =
@@ -291,7 +305,7 @@ const char* const kScalarOps = R"(
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -323,7 +337,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -355,7 +369,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -388,7 +402,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -422,7 +436,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -457,7 +471,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -494,7 +508,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
init1 = f32[] parameter(1)
@@ -529,7 +543,7 @@ XLA_TEST_F(MultiOutputFusionTest,
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
p0 = f16[2,2,2]{2,1,0} parameter(0)
convert = f32[2,2,2]{2,1,0} convert(p0)
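
Because the layout-assignment pass is disabled for these modules, every shape the test builds must carry an explicit layout. A sketch of the two helpers the hunks switch to (builds only inside the TensorFlow tree; the wrapper functions are illustrative):

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Illustrative helper: {1, 0} lists dimensions minor-to-major, so this is
// the row-major layout that layout assignment would otherwise pick.
Shape RowMajorMatrix(int64 rows, int64 cols) {
  return ShapeUtil::MakeShapeWithLayout(F32, {rows, cols}, {1, 0});
}

// Equivalent for rank 2: the descending layout of a matrix is {1, 0}.
Shape RowMajorMatrixAlt(int64 rows, int64 cols) {
  return ShapeUtil::MakeShapeWithDescendingLayout(F32, {rows, cols});
}

}  // namespace xla
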
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index a080dd1732..9af9ea4a22 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -15,11 +15,11 @@ limitations under the License.
#include <array>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -29,16 +29,13 @@ limitations under the License.
namespace xla {
namespace {
-namespace str_util = tensorflow::str_util;
-namespace strings = tensorflow::strings;
-
struct ReduceLayout {
std::array<int64, 4> input_minor_to_major;
std::array<int64, 3> output_minor_to_major;
string ToString() const {
- return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_",
- str_util::Join(output_minor_to_major, "x"));
+ return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_",
+ absl::StrJoin(output_minor_to_major, "x"));
}
};
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 531648fe3e..0916a07f4f 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10};
string TestDataToString(const ::testing::TestParamInfo<int> data) {
int i = data.param;
- return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_",
- mantissa_sizes[i], "_mantissa_bits");
+ return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i],
+ "_mantissa_bits");
}
// The FPVAL macro allows us to write out the binary representation of the
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 2065271a7f..b93d838349 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -559,9 +560,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) {
*os << tensorflow::strings::Printf(
"R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(),
spec.bounds.size() - spec.reduce_dims.size(),
- tensorflow::str_util::Join(spec.bounds, "x").c_str(),
- tensorflow::str_util::Join(spec.layout, "").c_str(),
- tensorflow::str_util::Join(spec.reduce_dims, "").c_str());
+ absl::StrJoin(spec.bounds, "x").c_str(),
+ absl::StrJoin(spec.layout, "").c_str(),
+ absl::StrJoin(spec.reduce_dims, "").c_str());
}
// Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index ebf7fa30be..60167619a4 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -579,21 +581,20 @@ string R4ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), //
- "__window_bounds_",
- tensorflow::str_util::Join(param.window_bounds, "x"), //
- "__strides_", tensorflow::str_util::Join(param.strides, "x"), //
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), //
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), //
- "__layout_", tensorflow::str_util::Join(param.layout, "_"), //
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
+ "__strides_", absl::StrJoin(param.strides, "x"), //
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"), //
+ "__pad_high_", absl::StrJoin(param.pad_high, "x"), //
+ "__layout_", absl::StrJoin(param.layout, "_"), //
(param.reducer == kAdd) ? "_add" : "_max");
CHECK(param.reducer == kAdd || param.reducer == kMax);
// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -935,15 +936,15 @@ string R3ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
- "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
- "__strides_", tensorflow::str_util::Join(param.strides, "x"),
- "__padding_", param.padding == Padding::kSame ? "same" : "valid",
- "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2],
- "__reducer_", param.reducer == kAdd ? "add" : "max");
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
+ absl::StrJoin(param.window_bounds, "x"), "__strides_",
+ absl::StrJoin(param.strides, "x"), "__padding_",
+ param.padding == Padding::kSame ? "same" : "valid", "__layout_",
+ param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
+ param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -1069,17 +1070,16 @@ string R2ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), //
- "__window_bounds_",
- tensorflow::str_util::Join(param.window_bounds, "x"), //
- "__strides_", tensorflow::str_util::Join(param.strides, "x"), //
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
- "__layout_", param.layout[0], "_", param.layout[1], //
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
+ "__strides_", absl::StrJoin(param.strides, "x"), //
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
+ absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
+ param.layout[1], //
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -1274,15 +1274,15 @@ string R1ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
- "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
- "__strides_", tensorflow::str_util::Join(param.strides, "x"),
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
- "__reducer_", param.reducer == kAdd ? "add" : "max");
+ string str =
+ absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
+ "__strides_", absl::StrJoin(param.strides, "x"),
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"),
+ "__pad_high_", absl::StrJoin(param.pad_high, "x"),
+ "__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 41e49b4003..60084f143d 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -43,10 +44,8 @@ struct ReverseSpec {
string ToTestCaseName() const {
return tensorflow::strings::Printf(
- "reverse_%s_in_dims_%s_%s",
- tensorflow::str_util::Join(input_dims, "x").c_str(),
- tensorflow::str_util::Join(reversal, "x").c_str(),
- use_bfloat16 ? "bf16" : "f32");
+ "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x").c_str(),
+ absl::StrJoin(reversal, "x").c_str(), use_bfloat16 ? "bf16" : "f32");
}
};
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index e42c71eb28..cf2d453f43 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index d865c414fd..c57bbbd1e4 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <vector>
#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -34,8 +36,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::str_util::Join;
-
class SliceTest : public ClientLibraryTestBase {};
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
@@ -449,13 +449,11 @@ struct R4Spec {
string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
const R4Spec& spec = data.param;
- return tensorflow::strings::StrCat( //
- "input_", Join(spec.input_dims, "x"), //
- "__layout_", Join(spec.input_layout, ""), //
- "__starts_", Join(spec.slice_starts, "x"), //
- "__limits_", Join(spec.slice_limits, "x"), //
- "__strides_", Join(spec.slice_strides, "x") //
- );
+ return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"),
+ "__layout_", absl::StrJoin(spec.input_layout, ""),
+ "__starts_", absl::StrJoin(spec.slice_starts, "x"),
+ "__limits_", absl::StrJoin(spec.slice_limits, "x"),
+ "__strides_", absl::StrJoin(spec.slice_strides, "x"));
}
class SliceR4Test : public ClientLibraryTestBase,
diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc
index be35ec6c6e..a9874a9186 100644
--- a/tensorflow/compiler/xla/tests/test_macros.cc
+++ b/tensorflow/compiler/xla/tests/test_macros.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include <string>
#include <unordered_map>
-#include "tensorflow/core/lib/strings/str_util.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
@@ -44,7 +46,7 @@ ManifestT ReadManifest() {
string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
- std::vector<string> lines = tensorflow::str_util::Split(contents, '\n');
+ std::vector<string> lines = absl::StrSplit(contents, '\n');
for (string& line : lines) {
auto comment = line.find("//");
if (comment != string::npos) {
@@ -53,8 +55,8 @@ ManifestT ReadManifest() {
if (line.empty()) {
continue;
}
- tensorflow::str_util::StripTrailingWhitespace(&line);
- std::vector<string> pieces = tensorflow::str_util::Split(line, ' ');
+ absl::StripTrailingAsciiWhitespace(&line);
+ std::vector<string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (int64 i = 1; i < pieces.size(); ++i) {
@@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name,
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
- auto it = manifest.find(
- tensorflow::strings::StrCat(test_case_name, ".", test_name));
+ auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
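
The manifest-parsing loop above maps directly onto absl: StrSplit materializes the lines and absl::StripTrailingAsciiWhitespace trims in place. A runnable sketch of the same shape (the sample manifest text is made up):

#include <iostream>
#include <string>
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"

int main() {
  std::string contents = "FooTest.Bar cpu gpu // flaky\nBazTest interpreter\n";
  // Same shape as ReadManifest above: split into lines, drop comments,
  // trim trailing whitespace, then split each line on spaces.
  std::vector<std::string> lines = absl::StrSplit(contents, '\n');
  for (std::string& line : lines) {
    auto comment = line.find("//");
    if (comment != std::string::npos) line = line.substr(0, comment);
    absl::StripTrailingAsciiWhitespace(&line);
    if (line.empty()) continue;
    std::vector<std::string> pieces = absl::StrSplit(line, ' ');
    std::cout << "test: " << pieces[0]
              << ", platforms: " << pieces.size() - 1 << "\n";
  }
}
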
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2f1d97b25d..21c58e075e 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -408,8 +408,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
-Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) {
- return HloVerifier(allow_mixed_precision).Run(module).status();
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision) {
+ return HloVerifier(/*layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision)
+ .Run(module)
+ .status();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 1aca1d8ef7..277d53d423 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -95,8 +95,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
// Check that a given module satisfies various constraints before trying to
// execute it.
-Status VerifyHloModule(HloModule* const module,
- bool allow_mixed_precision = false);
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision);
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 2bdbd08309..c7eb9e2dbe 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -15,11 +15,10 @@ limitations under the License.
#include <array>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -67,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -84,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
"param"));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -101,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr(
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 20ae68ab74..8f80a9f3e4 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -190,25 +190,6 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
SignAbsTestHelper<complex64>();
}
-XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
- XlaBuilder builder(TestName());
- auto arg = ConstantR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- Abs(arg);
-
- ComputeAndCompareR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
-}
-
-XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
- XlaBuilder builder(TestName());
- auto arg = ConstantR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- Sign(arg);
-
- ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
-}
-
XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
XlaBuilder builder(TestName());
auto arg = ConstantR2<float>(&builder, {{1.0, -2.0}, {-3.0, 4.0}});
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index e12e095ecd..6a7ddd9b55 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -17,6 +17,9 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -30,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -82,8 +84,7 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
- tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
- {}) {
+ tensorflow::gtl::ArraySlice<absl::string_view> opcodes_to_ignore = {}) {
string separator = "[^:]*:: +";
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
@@ -100,7 +101,7 @@ Status ParseOneProfileOutputLine(
string match_opcode =
expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])";
- string regexp_pattern = tensorflow::strings::StrCat(
+ string regexp_pattern = absl::StrCat(
" +", match_cycles, separator, match_usecs, separator, match_flops,
separator, match_trops, separator, match_bytes_per_sec, separator,
match_bytes_per_cycle, separator, match_opcode);
@@ -205,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
rhs_shape);
std::vector<string> profile_output_lines =
- tensorflow::str_util::Split(profile_output, '\n');
+ absl::StrSplit(profile_output, '\n');
gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
@@ -292,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
matrix_shape);
std::vector<string> profile_output_lines =
- tensorflow::str_util::Split(profile_output, '\n');
+ absl::StrSplit(profile_output, '\n');
auto while_body_profile_start =
- absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
- return tensorflow::str_util::StartsWith(s,
- "Execution profile for body");
+ absl::c_find_if(profile_output_lines, [](absl::string_view s) {
+ return absl::StartsWith(s, "Execution profile for body");
});
ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
- auto while_body_profile_end =
- std::find_if(while_body_profile_start, profile_output_lines.end(),
- [](tensorflow::StringPiece s) {
- return tensorflow::str_util::StartsWith(
- s, "********** microseconds report **********");
- });
+ auto while_body_profile_end = std::find_if(
+ while_body_profile_start, profile_output_lines.end(),
+ [](absl::string_view s) {
+ return absl::StartsWith(s, "********** microseconds report **********");
+ });
// We emit a blank line before the "********** microseconds report **********"
// line.
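
absl::StartsWith plus a string_view lambda reproduces the old str_util scan over the profile output. A runnable sketch over made-up profile text:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"

int main() {
  std::string profile_output =
      "header\nExecution profile for body\nline a\nline b\n"
      "********** microseconds report **********\n";
  std::vector<std::string> lines = absl::StrSplit(profile_output, '\n');
  // Locate the body profile the same way the test does; std::string
  // elements convert implicitly to the lambda's absl::string_view.
  auto start =
      std::find_if(lines.begin(), lines.end(), [](absl::string_view s) {
        return absl::StartsWith(s, "Execution profile for body");
      });
  auto end = std::find_if(start, lines.end(), [](absl::string_view s) {
    return absl::StartsWith(s, "********** microseconds report **********");
  });
  std::cout << "body profile spans " << (end - start) << " lines\n";  // 3
}
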
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index a075195618..15603619b6 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) {
// If the --benchmarks flag is passed in then only run the benchmarks, not the
// tests.
for (int i = 1; i < argc; i++) {
- tensorflow::StringPiece arg(argv[i]);
- if (arg == "--benchmarks" ||
- tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
+ absl::string_view arg(argv[i]);
+ if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) {
const char* pattern = nullptr;
- if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
+ if (absl::StartsWith(arg, "--benchmarks=")) {
pattern = argv[i] + strlen("--benchmarks=");
} else {
// Handle flag of the form '--benchmarks foo' (no '=').
- if (i + 1 >= argc ||
- tensorflow::str_util::StartsWith(argv[i + 1], "--")) {
+ if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) {
LOG(ERROR) << "--benchmarks flag requires an argument.";
return 2;
}
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 7de2c39b38..9835e3d803 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -21,24 +21,27 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
- tensorflow::StringPiece path) {
- CHECK(!tensorflow::str_util::EndsWith(path, ".gz"))
+ absl::string_view path) {
+ CHECK(!absl::EndsWith(path, ".gz"))
<< "TextLiteralReader no longer supports reading .gz files";
std::unique_ptr<tensorflow::RandomAccessFile> file;
Status s =
@@ -54,33 +57,6 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
: file_(file) {}
-namespace {
-// This is an optimized version of tensorflow::str_util::Split which uses
-// StringPiece for the delimited strings and uses an out parameter for the
-// result to avoid vector creation/destruction.
-void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim,
- std::vector<tensorflow::StringPiece>* result) {
- result->clear();
-
- if (text.empty()) {
- return;
- }
-
- // The following loop is a little strange: its bound is text.size() + 1
- // instead of the more typical text.size().
- // The final iteration of the loop (when i is equal to text.size()) handles
- // the trailing token.
- size_t token_start = 0;
- for (size_t i = 0; i < text.size() + 1; i++) {
- if (i == text.size() || text[i] == delim) {
- tensorflow::StringPiece token(text.data() + token_start, i - token_start);
- result->push_back(token);
- token_start = i + 1;
- }
- }
-}
-} // namespace
-
StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
tensorflow::io::RandomAccessInputStream stream(file_.get());
tensorflow::io::BufferedInputStream buf(&stream, 65536);
@@ -90,11 +66,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
return s;
}
- tensorflow::StringPiece sp(shape_string);
- if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) {
- string tmp = std::string(sp);
- shape_string = tmp;
- }
+ absl::StripAsciiWhitespace(&shape_string);
TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
if (shape.element_type() != F32) {
return Unimplemented(
@@ -105,35 +77,33 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
auto result = absl::make_unique<Literal>(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
result->PopulateWithValue<float>(fill);
- std::vector<tensorflow::StringPiece> pieces;
- std::vector<tensorflow::StringPiece> coordinates;
+ std::vector<absl::string_view> pieces;
+ std::vector<absl::string_view> coordinates;
std::vector<int64> coordinate_values;
string line;
while (buf.ReadLine(&line).ok()) {
- SplitByDelimToStringPieces(line, ':', &pieces);
- tensorflow::StringPiece coordinates_string = pieces[0];
- tensorflow::StringPiece value_string = pieces[1];
- tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string);
- tensorflow::str_util::RemoveWhitespaceContext(&value_string);
- if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) {
+ pieces = absl::StrSplit(line, ':');
+ absl::string_view coordinates_string =
+ absl::StripAsciiWhitespace(pieces[0]);
+ absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]);
+ if (!absl::ConsumePrefix(&coordinates_string, "(")) {
return InvalidArgument(
"expected '(' at the beginning of coordinates: \"%s\"", line.c_str());
}
- if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) {
+ if (!absl::ConsumeSuffix(&coordinates_string, ")")) {
return InvalidArgument("expected ')' at the end of coordinates: \"%s\"",
line.c_str());
}
float value;
- if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(),
- &value)) {
+ if (!absl::SimpleAtof(absl::string_view(value_string), &value)) {
return InvalidArgument("could not parse value as float: \"%s\"",
- std::string(value_string).c_str());
+ string(value_string).c_str());
}
- SplitByDelimToStringPieces(coordinates_string, ',', &coordinates);
+ coordinates = absl::StrSplit(coordinates_string, ',');
coordinate_values.clear();
- for (tensorflow::StringPiece piece : coordinates) {
+ for (absl::string_view piece : coordinates) {
int64 coordinate_value;
- if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) {
+ if (!absl::SimpleAtoi(piece, &coordinate_value)) {
return InvalidArgument(
"could not parse coordinate member as int64: \"%s\"",
std::string(piece).c_str());
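
The rewritten reader leans on five absl calls: StrSplit, StripAsciiWhitespace, ConsumePrefix/ConsumeSuffix, and SimpleAtof/SimpleAtoi. A runnable sketch parsing one line of the text-literal format (the sample line is made up):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"

int main() {
  std::string line = "(1, 2): 0.25";
  // Split on ':' into views, then trim whitespace from each side.
  std::vector<absl::string_view> pieces = absl::StrSplit(line, ':');
  absl::string_view coords = absl::StripAsciiWhitespace(pieces[0]);
  absl::string_view value_str = absl::StripAsciiWhitespace(pieces[1]);
  float value;
  if (!absl::ConsumePrefix(&coords, "(") ||
      !absl::ConsumeSuffix(&coords, ")") ||
      !absl::SimpleAtof(value_str, &value)) {
    std::cerr << "parse error\n";
    return 1;
  }
  for (absl::string_view piece : absl::StrSplit(coords, ',')) {
    int64_t coordinate;
    if (absl::SimpleAtoi(absl::StripAsciiWhitespace(piece), &coordinate)) {
      std::cout << "coordinate: " << coordinate << "\n";
    }
  }
  std::cout << "value: " << value << "\n";
}
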
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index 708e8c80d8..b265640802 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -18,11 +18,11 @@ limitations under the License.
#include <memory>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,8 +41,7 @@ class TextLiteralReader {
public:
// See class comment -- reads a file in its entirety (there must be only one
// literal in the text file path provided).
- static StatusOr<std::unique_ptr<Literal>> ReadPath(
- tensorflow::StringPiece path);
+ static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path);
private:
// Ownership of file is transferred.
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 24e0784741..00147015a6 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -17,23 +17,23 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-/* static */ Status TextLiteralWriter::WriteToPath(
- const Literal& literal, tensorflow::StringPiece path) {
+/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal,
+ absl::string_view path) {
std::unique_ptr<tensorflow::WritableFile> f;
- auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f);
+ auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f);
if (!s.ok()) {
return s;
}
@@ -51,11 +51,10 @@ namespace xla {
if (!status.ok()) {
return;
}
- string coordinates = tensorflow::strings::StrCat(
- "(", tensorflow::str_util::Join(indices, ", "), ")");
+ string coordinates =
+ absl::StrCat("(", absl::StrJoin(indices, ", "), ")");
- status = f_ptr->Append(
- tensorflow::strings::StrCat(coordinates, ": ", value, "\n"));
+ status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n"));
});
auto ignored = f->Close();
return status;
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
index 159ac1b7e1..34de8572d6 100644
--- a/tensorflow/compiler/xla/text_literal_writer.h
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -37,8 +37,7 @@ namespace xla {
// This should be readable by xla::TextLiteralReader.
class TextLiteralWriter {
public:
- static Status WriteToPath(const Literal& literal,
- tensorflow::StringPiece path);
+ static Status WriteToPath(const Literal& literal, absl::string_view path);
private:
TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter);
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 40d28a57bf..1e45588148 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -24,6 +24,7 @@ tf_cc_binary(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
],
)
@@ -191,6 +192,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index f0af0580c1..7aedd1da98 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -44,10 +44,9 @@ class OperationDumper : public DfsHloVisitorWithDefault {
explicit OperationDumper(const string& path) : path_(path) {}
Status DefaultAction(HloInstruction* hlo) override {
- string params = tensorflow::str_util::Join(
+ string params = absl::StrJoin(
hlo->operands(), ", ", [](string* out, const HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, ShapeUtil::HumanString(operand->shape()));
+ absl::StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
// Spit `op_name(params...) -> result_type :: path` to stdout.
std::cout << tensorflow::strings::Printf(
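
The three-argument StrJoin overload takes a formatter that appends each element's rendering to *out, as the hunk above shows. A runnable sketch with a made-up operand type standing in for HloInstruction:

#include <iostream>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

struct Operand {
  std::string shape;
};

int main() {
  std::vector<Operand> operands = {{"f32[2,3]"}, {"f32[3,4]"}};
  // The formatter lambda renders one element into *out via StrAppend,
  // mirroring the OperationDumper::DefaultAction call site.
  std::string params = absl::StrJoin(
      operands, ", ", [](std::string* out, const Operand& operand) {
        absl::StrAppend(out, operand.shape);
      });
  std::cout << "dot(" << params << ")\n";  // dot(f32[2,3], f32[3,4])
}
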
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index eb7bff053b..75b63c3b84 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/env.h"
@@ -67,7 +67,7 @@ int main(int argc, char** argv) {
floats.push_back(value);
}
- tensorflow::StringPiece content(
+ tensorflow::StringPiece content( // non-absl ok
tensorflow::bit_cast<const char*>(floats.data()),
floats.size() * sizeof(float));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index e43498e381..85f05b7b8d 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -18,11 +18,13 @@ limitations under the License.
#include <stdarg.h>
#include <numeric>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@@ -54,16 +56,16 @@ ScopedLoggingTimer::~ScopedLoggingTimer() {
}
}
-Status AddStatus(Status prior, tensorflow::StringPiece context) {
+Status AddStatus(Status prior, absl::string_view context) {
CHECK(!prior.ok());
- return Status{prior.code(), tensorflow::strings::StrCat(
- context, ": ", prior.error_message())};
+ return Status{prior.code(),
+ absl::StrCat(context, ": ", prior.error_message())};
}
-Status AppendStatus(Status prior, tensorflow::StringPiece context) {
+Status AppendStatus(Status prior, absl::string_view context) {
CHECK(!prior.ok());
- return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
- ": ", context)};
+ return Status{prior.code(),
+ absl::StrCat(prior.error_message(), ": ", context)};
}
// Implementation note: we can't common these out (without using macros) because
@@ -146,16 +148,13 @@ Status Unavailable(const char* format, ...) {
return WithLogBacktrace(tensorflow::errors::Unavailable(message));
}
-string Reindent(tensorflow::StringPiece original,
- const tensorflow::StringPiece indentation) {
- std::vector<string> pieces = tensorflow::str_util::Split(
- tensorflow::StringPiece(original.data(), original.size()), '\n');
- return tensorflow::str_util::Join(
- pieces, "\n", [indentation](string* out, string s) {
- tensorflow::StringPiece piece(s);
- tensorflow::str_util::RemoveWhitespaceContext(&piece);
- tensorflow::strings::StrAppend(out, indentation, piece);
- });
+string Reindent(absl::string_view original,
+ const absl::string_view indentation) {
+ std::vector<string> pieces =
+ absl::StrSplit(absl::string_view(original.data(), original.size()), '\n');
+ return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) {
+ absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s));
+ });
}
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
@@ -234,20 +233,20 @@ bool HasInteriorPadding(const PaddingConfig& config) {
namespace {
string HumanReadableNumOps(double flops, double nanoseconds,
- tensorflow::StringPiece op_prefix) {
+ absl::string_view op_prefix) {
if (nanoseconds == 0) {
- return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s");
+ return absl::StrCat("NaN ", op_prefix, "OP/s");
}
double nano_flops = flops / nanoseconds;
string throughput = tensorflow::strings::HumanReadableNum(
static_cast<int64>(nano_flops * 1e9));
- tensorflow::StringPiece sp(throughput);
+ absl::string_view sp(throughput);
// Use the more common "G(FLOPS)", rather than "B(FLOPS)"
- if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case
- tensorflow::str_util::EndsWith(sp, "b")) {
+ if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case
+ absl::EndsWith(sp, "b")) {
*throughput.rbegin() = 'G';
}
- throughput += tensorflow::strings::StrCat(op_prefix, "OP/s");
+ throughput += absl::StrCat(op_prefix, "OP/s");
return throughput;
}
} // namespace
@@ -260,8 +259,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
return HumanReadableNumOps(trops, nanoseconds, "TR");
}
-void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
- int lineno) {
+void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
const int orig_sev = sev;
if (sev == tensorflow::FATAL) {
sev = tensorflow::ERROR;
@@ -275,7 +273,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
size_t cur = 0;
while (cur < text.size()) {
size_t eol = text.find('\n', cur);
- if (eol == tensorflow::StringPiece::npos) {
+ if (eol == absl::string_view::npos) {
eol = text.size();
}
auto msg = text.substr(cur, eol - cur);
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index efeafbc53a..671ef17f36 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -26,16 +26,16 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -202,8 +202,8 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base,
// Adds some context information to the error message in a
// Status. This is useful as Statuses are
// propagated upwards.
-Status AddStatus(Status prior, tensorflow::StringPiece context);
-Status AppendStatus(Status prior, tensorflow::StringPiece context);
+Status AddStatus(Status prior, absl::string_view context);
+Status AppendStatus(Status prior, absl::string_view context);
// Status error shorthands -- printfs the arguments to be
// used as an error message and returns a status in the canonical
@@ -222,26 +222,26 @@ Status InvalidArgumentV(const char* format, va_list args);
template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
- return InvalidArgument(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return InvalidArgument("%s",
+ absl::StrCat(std::forward<Args>(concat)...).c_str());
}
template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
- return Unimplemented(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return Unimplemented("%s",
+ absl::StrCat(std::forward<Args>(concat)...).c_str());
}
template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
- return InternalError(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return InternalError("%s",
+ absl::StrCat(std::forward<Args>(concat)...).c_str());
}
template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
- return ResourceExhausted(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return ResourceExhausted("%s",
+ absl::StrCat(std::forward<Args>(concat)...).c_str());
}
// Splits the lines of the original, replaces leading whitespace with the prefix
@@ -250,8 +250,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) {
//
// Note: even different amounts of leading whitespace on different lines will be
// uniformly replaced with "indentation".
-string Reindent(tensorflow::StringPiece original,
- tensorflow::StringPiece indentation);
+string Reindent(absl::string_view original, absl::string_view indentation);
// Checks whether permutation is a permutation of the [0, rank) integer range.
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
@@ -313,7 +312,7 @@ string CommaSeparatedString(const Container& c, const char* prefix = "",
string comma_separated = prefix;
const char* separator = "";
for (const auto& entry : c) {
- tensorflow::strings::StrAppend(&comma_separated, separator, entry);
+ absl::StrAppend(&comma_separated, separator, entry);
separator = ", ";
}
comma_separated += suffix;
@@ -395,8 +394,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);
// Split the text into multiple lines and log each line with the given
// severity, filename, and line number.
-void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
- int lineno);
+void LogLines(int sev, absl::string_view text, const char* fname, int lineno);
template <typename T>
inline bool IsPowerOfTwo(T x) {
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index f11123ca24..44fb1bdc38 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -17,10 +17,9 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -49,8 +48,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
}
/* static */ string ToString(const WindowDimension& dim) {
- using tensorflow::strings::StrAppend;
- using tensorflow::strings::StrCat;
+ using absl::StrAppend;
+ using absl::StrCat;
string str = StrCat("(size=", dim.size());
if (dim.stride() != 1) {
StrAppend(&str, ",stride=", dim.stride());
@@ -75,8 +74,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
}
string ToString(const Window& window) {
- using tensorflow::strings::StrAppend;
- using tensorflow::strings::StrCat;
+ using absl::StrAppend;
+ using absl::StrCat;
string str;
const auto add_field =
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index ab4328d459..66983801bf 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -181,6 +181,7 @@ cc_library(
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
"//tensorflow/contrib/data:dataset_ops_op_lib",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index f7dd3183b0..8d314250a0 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -310,7 +310,9 @@ class ControlFlowTransformer(converter.Base):
template = """
def extra_test_name(state_ssf):
return extra_test_expr
- def body_name(iterate, state_ssf):
+ def body_name(loop_vars, state_ssf):
+ # Workaround for PEP-3113
+ iterate = loop_vars
body
return state_ssf,
state_ast_tuple = ag__.for_stmt(
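A minimal sketch of the Python 3 constraint behind the `loop_vars` workaround above: PEP 3113 removed tuple parameter unpacking from function signatures, so the generated loop body must take a single argument and unpack it in the body (names below are illustrative):

```python
# Python 2 allowed tuple unpacking directly in the signature:
#   def body((i, x), z): ...   # SyntaxError under Python 3 (PEP 3113)

# Python 3 style, mirroring the body_name(loop_vars, state_ssf) template:
def body(loop_vars, z):
  i, x = loop_vars  # manual unpacking replaces the removed syntax
  return z + x + i

print(body((2, 5), 1))  # 8
```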
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 02bc00dbc8..2a6f3cb395 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -217,5 +217,13 @@ class ControlFlowTest(converter_testing.TestCase):
with self.assertRaises(transformer.AutographParseError):
control_flow.transform(node, ctx)
+ def test_for_tuple_unpacking(self):
+ def test_fn(x_list):
+ z = tf.constant(0) # pylint:disable=undefined-variable
+ for i, x in enumerate(x_list):
+ z = z + x + i
+ return z
+
+ self.assertTransformedResult(test_fn, [3, 3], 7)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD
index 9ef1ac9663..29a92444bb 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/contrib/autograph/pyct/testing/BUILD
@@ -34,8 +34,10 @@ py_test(
srcs = ["codegen_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
"no_windows",
"nomsan",
+ "notap",
],
deps = [
":testing",
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 5821d51bca..5e6c1520a2 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,6 +25,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@LMDBDataset
@@RandomDataset
@@Reducer
@@SqlDataset
@@ -49,6 +50,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
+@@parse_example_dataset
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
@@ -89,10 +91,12 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
+from tensorflow.contrib.data.python.ops.readers import LMDBDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index 4d1603a561..ec6cb37193 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -77,6 +77,17 @@ cc_library(
)
cc_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
@@ -117,6 +128,7 @@ cc_library(
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
":indexed_dataset",
+ ":lmdb_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..80f39992fb
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
@@ -0,0 +1,215 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace {
+
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
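+        // Each call emits at most one (key, value) pair. The loop advances
+        // the cursor past the record just emitted; once a database is
+        // exhausted it falls through to open the next file, or reports
+        // end_of_sequence when no files remain.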
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+          int val =
+              mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
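+      // If `filename` refers to a regular file rather than a database
+      // directory, open it with MDB_NOSUBDIR so the path is treated as the
+      // data file itself.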
+ struct stat source_stat;
+ if (stat(filename.c_str(), &source_stat) == 0 &&
+ (source_stat.st_mode & S_IFREG)) {
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index cc5e250ea1..ae104d55bd 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -266,4 +266,13 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("LMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 803a3b33fa..9e2697534c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -4,7 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "batch_dataset_op_test",
@@ -194,6 +195,31 @@ py_test(
)
py_test(
+ name = "lmdb_dataset_op_test",
+ size = "medium",
+ srcs = ["lmdb_dataset_op_test.py"],
+ data = ["//tensorflow/core:lmdb_testdata"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
new file mode 100644
index 0000000000..7bc582ebaa
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LMDBDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+prefix_path = "tensorflow/core/lib"
+
+
+class LMDBDatasetTest(test.TestCase):
+
+ def setUp(self):
+ super(LMDBDatasetTest, self).setUp()
+ # Copy database out because we need the path to be writable to use locks.
+ path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
+ self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
+ shutil.copy(path, self.db_path)
+
+ def testReadFromFile(self):
+ filename = self.db_path
+
+ filenames = constant_op.constant([filename], dtypes.string)
+ num_repeats = 2
+
+ dataset = readers.LMDBDataset(filenames).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(10): # 10 records.
+ k = compat.as_bytes(str(i))
+          v = compat.as_bytes(chr(ord("a") + i))
+ self.assertEqual((k, v), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 7b9ea191a4..4881f63ab9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -318,6 +318,19 @@ py_test(
)
py_test(
+ name = "parse_example_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parse_example_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "prefetch_dataset_serialization_test",
size = "small",
srcs = ["prefetch_dataset_serialization_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 9fdbcb66bf..595cecef4d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -510,7 +510,6 @@ class DatasetSerializationTestBase(test.TestCase):
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
@@ -616,29 +615,40 @@ class DatasetSerializationTestBase(test.TestCase):
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
# do not support tuples we flatten the tensors and restore the shape in
# `_get_iterator_ops_from_collection`.
-
- # TODO(shivaniagrwal): `output_classes` is a nested structure of classes,
- # this base class is specific to current test cases. Update when tests are
- # added with `output_classes` as a nested structure with at least one of the
- # component being `tf.SparseTensor`.
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+    if sparse_tensors:  # Specific to deprecated `from_sparse_tensor_slices`.
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
- else:
- for el in nest.flatten(get_next):
- ops.add_to_collection("iterator_ops", el)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+    if sparse_tensors:  # Specific to deprecated `from_sparse_tensor_slices`.
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- else:
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), all_ops[1:])
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
def _get_output_types(self, ds_fn):
with ops.Graph().as_default():
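To make the collection layout above concrete, here is a minimal, self-contained sketch (plain Python stand-ins, no TensorFlow) of how each sparse component consumes three consecutive collection slots while a dense component consumes one:

```python
# Collection layout: [init_op, <3 slots per SparseTensor | 1 per Tensor>...]
all_ops = ["init_op", "sp_indices", "sp_values", "sp_dense_shape", "dense"]
output_classes = ["SparseTensor", "Tensor"]  # stand-ins for the real classes

restored = []
i = 1  # slot 0 is always the iterator's init op
for output_class in output_classes:
  if output_class == "SparseTensor":
    restored.append(tuple(all_ops[i:i + 3]))
    i += 3
  else:
    restored.append(all_ops[i])
    i += 1

assert restored == [("sp_indices", "sp_values", "sp_dense_shape"), "dense"]
```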
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
new file mode 100644
index 0000000000..d3fa84e74c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParseExampleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.platform import test
+
+
+class ParseExampleDatasetSerializationTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def ParseExampleDataset(self, num_repeat, batch_size):
+ return self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_repeat,
+ batch_size=batch_size,
+ reader_num_threads=5,
+ parser_num_threads=10)
+
+ def testSerializationCore(self):
+ num_repeat = 5
+ batch_size = 2
+ num_outputs = self._num_records * self._num_files * num_repeat // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self.ParseExampleDataset(
+ num_repeat=num_repeat, batch_size=batch_size),
+ lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 0bd5b403e2..4b45cc7e36 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -82,7 +82,6 @@ py_library(
":interleave_ops",
":parsing_ops",
":shuffle_ops",
- ":stats_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index cca9bf6742..54a92ab185 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -216,25 +216,46 @@ def sample_from_datasets(datasets, weights=None, seed=None):
length of the `datasets` element.
"""
num_datasets = len(datasets)
- if weights is None:
- weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat()
- elif not isinstance(weights, dataset_ops.Dataset):
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError("`weights` must be a vector of length `len(datasets)`.")
- weights = dataset_ops.Dataset.from_tensors(weights).repeat()
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
- def select_dataset(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
- selector_input = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ selector_input = random_ops.RandomDataset(seed).batch(2).map(
+ select_dataset_constant_logits)
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ selector_input = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)
+ )).map(select_dataset_varying_logits)
return _DirectedInterleaveDataset(selector_input, datasets)
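For reference, a usage sketch of the three code paths above (no weights, a constant weights tensor, and a weights dataset); the input datasets are illustrative:

```python
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops

ds_a = dataset_ops.Dataset.from_tensors(0).repeat()
ds_b = dataset_ops.Dataset.from_tensors(1).repeat()

# No weights: a single constant logits row selects inputs uniformly.
mixed = interleave_ops.sample_from_datasets([ds_a, ds_b])

# Constant weights: converted to logits once; no per-element zip needed.
mixed = interleave_ops.sample_from_datasets([ds_a, ds_b], weights=[0.9, 0.1])

# A weights dataset: logits vary per drawn element, zipped with the seeds.
weights_ds = dataset_ops.Dataset.from_tensors([0.5, 0.5]).repeat()
mixed = interleave_ops.sample_from_datasets([ds_a, ds_b], weights=weights_ds)
```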
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index f868653554..2701605e64 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -102,8 +102,6 @@ class _ParseExampleDataset(dataset_ops.Dataset):
return self._output_classes
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
# TODO(b/111553342): add arguments names and example names as well.
def parse_example_dataset(features, num_parallel_calls=1):
"""A transformation that parses `Example` protos into a `dict` of tensors.
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index cafe0a4091..29005859d7 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -27,7 +27,6 @@ from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_da
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
@@ -326,7 +325,6 @@ def make_csv_dataset(
shuffle_seed=None,
prefetch_buffer_size=1,
num_parallel_reads=1,
- num_parallel_parser_calls=2,
sloppy=False,
num_rows_for_inference=100,
compression_type=None,
@@ -393,8 +391,6 @@ def make_csv_dataset(
batches consumed per training step.
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
- num_parallel_parser_calls: Number of parallel invocations of the CSV parsing
- function on CSV records.
sloppy: If `True`, reading performance will be improved at
the cost of non-deterministic ordering. If `False`, the order of elements
produced is deterministic prior to shuffling (elements are still
@@ -503,7 +499,7 @@ def make_csv_dataset(
# indefinitely, and all batches will be full-sized.
dataset = dataset.batch(batch_size=batch_size,
drop_remainder=num_epochs is None)
- dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
+ dataset = dataset.map(map_fn)
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
@@ -972,3 +968,49 @@ class SqlDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
+
+
+class LMDBDataset(dataset_ops.Dataset):
+ """A LMDB Dataset that reads the lmdb file."""
+
+ def __init__(self, filenames):
+ """Create a `LMDBDataset`.
+
+    `LMDBDataset` allows a user to read data from an LMDB (.mdb) file as
+    (key, value) pairs sequentially.
+    For example:
+    ```python
+    dataset = tf.contrib.data.LMDBDataset("/foo/bar.mdb")
+    iterator = dataset.make_one_shot_iterator()
+    next_element = iterator.get_next()
+    # Prints the (key, value) pairs inside an LMDB file.
+    with tf.Session() as sess:
+      while True:
+        try:
+          print(sess.run(next_element))
+        except tf.errors.OutOfRangeError:
+          break
+ ```
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ """
+ super(LMDBDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+
+ def _as_variant_tensor(self):
+ return contrib_gen_dataset_ops.lmdb_dataset(
+ self._filenames,
+ output_types=nest.flatten(self.output_types),
+ output_shapes=nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor, ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return dtypes.string, dtypes.string
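Since `LMDBDataset` is an ordinary `Dataset`, it composes with the standard transformations; the new kernel test exercises exactly this shape of pipeline (the path below is illustrative):

```python
from tensorflow.contrib.data.python.ops import readers

# Each element is a (key, value) pair of scalar tf.string tensors.
dataset = readers.LMDBDataset(["/tmp/data.mdb"]).repeat(2)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()  # run iterator.initializer first in a session
```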
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index bcf9b3c568..8173b5d4ba 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -106,6 +106,38 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ ":values",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -239,35 +271,6 @@ py_test(
],
)
-py_test(
- name = "parameter_server_strategy_test",
- srcs = ["parameter_server_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- ":combinations",
- ":multi_worker_test_base",
- ":parameter_server_strategy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:layers",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index a411ca870e..2331444261 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -68,11 +68,11 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
self._cluster_spec = multi_worker_util.normalize_cluster_spec(
cluster_spec)
worker_device = "/job:%s/task:%d" % (task_type, task_id)
- num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
- if "chief" in self._cluster_spec.as_dict():
- num_workers += 1
+ num_workers = len(self._cluster_spec.as_dict().get("worker", [])) + len(
+ self._cluster_spec.as_dict().get("chief", []))
if not num_workers:
- raise ValueError("`task_type` shoud be in `cluster_spec`.")
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
task_id)
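A small sketch of the revised worker count, which now keys on the `worker` and `chief` job names instead of `task_type` (the cluster spec is hypothetical):

```python
cluster_spec = {
    "chief": ["host0:2222"],
    "worker": ["host1:2222", "host2:2222", "host3:2222"],
}
num_workers = len(cluster_spec.get("worker", [])) + len(
    cluster_spec.get("chief", []))
assert num_workers == 4  # the chief also runs a between-graph client
```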
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index c679fc8810..0d966d0e90 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -41,52 +39,43 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class DistributedCollectiveAllReduceStrategyTest(
- multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+class CollectiveAllReduceStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
- @classmethod
- def setUpClass(cls):
- """Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ]
- }
-
def setUp(self):
self._run_options = config_pb2.RunOptions()
self._run_options.experimental.collective_graph_key = 6
self._sess_config = config_pb2.ConfigProto()
- self._sess_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# tests.
- DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
- super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
+ super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
num_gpus_per_worker=num_gpus)
- distribution.configure(
- cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
+ if task_type and task_id is not None:
+ distribution.configure(
+ cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ CollectiveAllReduceStrategyTestBase.collective_key_base)
distribution._collective_keys = collective_keys
distribution._cross_tower_ops._collective_keys = collective_keys
- return distribution, self._workers[task_id].target
+ if task_type and task_id is not None:
+ return distribution, 'grpc://' + self._cluster_spec[task_type][task_id]
+ else:
+ return distribution, ''
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_object(task_type, task_id, num_gpus)
@@ -154,12 +143,6 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertLess(error_after, error_before)
return error_after < error_before
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testMinimizeLossGraph(self, num_gpus):
- self._run_between_graph_clients(self._test_minimize_loss_graph,
- self._cluster_spec, num_gpus)
-
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -184,13 +167,35 @@ class DistributedCollectiveAllReduceStrategyTest(
sess.run(
variables.global_variables_initializer(), options=self._run_options)
+
x_value, reduced_x_value = sess.run(
[x, reduced_x], options=self._run_options)
self.assertTrue(np.array_equal(x_value, reduced_x_value))
return np.array_equal(x_value, reduced_x_value)
+
+class DistributedCollectiveAllReduceStrategyTest(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testVariableInitialization(self, num_gpus):
if context.num_gpus() < num_gpus:
return
@@ -200,16 +205,46 @@ class DistributedCollectiveAllReduceStrategyTest(
num_gpus=num_gpus)
-class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
- parameterized.TestCase):
+class DistributedCollectiveAllReduceStrategyTestWithChief(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0, has_chief=True)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
+ self._run_options.experimental.collective_graph_key = 7
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class LocalCollectiveAllReduceStrategy(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
def testMinimizeLossGraph(self, num_gpus=2):
    # Collective ops don't support a strategy with only one device.
if context.num_gpus() < num_gpus:
return
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus)
- self._test_minimize_loss_graph(distribution)
+ self._test_minimize_loss_graph(None, None, num_gpus)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 97c4778f0d..2ad91d56e9 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -32,7 +32,6 @@ from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -379,12 +378,16 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
distribution=[
combinations.NamedDistribution(
"MirroredCPU",
- lambda: mirrored_strategy.MirroredStrategy(["/cpu:0"]),
- required_gpus=2),
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ required_gpus=0),
combinations.NamedDistribution(
"Mirrored1GPU",
- lambda: mirrored_strategy.MirroredStrategy(["/gpu:1"]),
- required_gpus=2), combinations.mirrored_strategy_with_two_gpus
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "Mirrored2GPUs",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ required_gpus=2),
],
mode=["graph"])
@@ -406,13 +409,8 @@ class MultiWorkerCollectiveAllReduceTest(
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- "fake_worker_0", "fake_worker_1", "fake_worker_2"
- ]
- }
def setUp(self):
super(MultiWorkerCollectiveAllReduceTest, self).setUp()
@@ -446,7 +444,8 @@ class MultiWorkerCollectiveAllReduceTest(
]
else:
devices = ["/job:%s/task:%d" % (task_type, task_id)]
- return collective_all_reduce_ops, devices, self._workers[task_id].target
+ return (collective_all_reduce_ops, devices,
+ "grpc://" + self._cluster_spec[task_type][task_id])
def _assert_values_equal(self, left, right, sess):
if isinstance(left, list):
@@ -473,7 +472,8 @@ class MultiWorkerCollectiveAllReduceTest(
num_workers = 1
worker_device = None
else:
- num_workers = len(self._workers)
+ num_workers = len(self._cluster_spec.get("chief", [])) + len(
+ self._cluster_spec.get("worker", []))
worker_device = "/job:%s/task:%d" % (task_type, task_id)
with ops.Graph().as_default(), \
ops.device(worker_device), \
@@ -551,7 +551,7 @@ class MultiWorkerCollectiveAllReduceTest(
return True
@combinations.generate(
- combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1))
def testReductionDistributed(self, num_gpus):
if context.num_gpus() < num_gpus:
return
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index ecaf60f350..e87b48ba41 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -276,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -289,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 612655a38a..ac2697958d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -888,8 +888,18 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
- mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+
+ # read_value == True
+ mirrored_var_result = self.evaluate(
+ mirrored_var.assign_add(6.0, read_value=True))
self.assertEquals(7.0, mirrored_var_result)
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+
+ # read_value == False
+ self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignAddMirroredVarTowerContext(self):
@@ -956,6 +966,8 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEquals(3.0, mirrored_var_result)
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignSubMirroredVarTowerContext(self):
@@ -1262,5 +1274,22 @@ class MultiWorkerMirroredStrategyTest(
self._test_minimize_loss_graph(self._get_distribution_strategy())
+class MultiWorkerMirroredStrategyTestWithChief(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=2, num_ps=0, has_chief=True)
+ cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
+
+ def testMinimizeLossGraph(self):
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(cluster_spec=self._cluster_spec)
+ self._test_minimize_loss_graph(strategy)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 3f44ab7700..969e126956 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -62,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase):
def model_fn(device_id):
assert isinstance(device_id, int)
+
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + str(device_id)
@@ -93,16 +94,15 @@ class MultiWorkerMirroredStrategyTest(test.TestCase):
def testDeviceScope(self):
"""Test the device scope of multi-worker MirroredStrategy."""
with context.graph_mode():
- strategy = mirrored_strategy.MirroredStrategy(
- num_gpus=context.num_gpus())
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
strategy.configure(
cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]})
with strategy.scope():
a = constant_op.constant(1.)
- with ops.device('/cpu:0'):
+ with ops.device("/cpu:0"):
b = constant_op.constant(1.)
- self.assertEqual(a.device, '/job:worker/task:0')
- self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
+ self.assertEqual(a.device, "/job:worker/task:0")
+ self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 249de01f08..18b4503eff 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -23,26 +23,105 @@ import copy
import threading
import numpy as np
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
-
-
-def create_in_process_cluster(num_workers, num_ps):
+from tensorflow.python.training import server_lib
+
+
+def _create_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False,
+ protocol='grpc',
+ worker_config=None,
+ ps_config=None):
+ """Creates and starts local servers and returns the cluster_spec dict."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {}
+ if num_workers > 0:
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+ if num_ps > 0:
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+ if has_eval:
+ cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ if has_chief:
+ cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ for i in range(num_workers):
+ server_lib.Server(
+ cs,
+ job_name='worker',
+ protocol=protocol,
+ task_index=i,
+ config=worker_config,
+ start=True)
+
+ for i in range(num_ps):
+ server_lib.Server(
+ cs,
+ job_name='ps',
+ protocol=protocol,
+ task_index=i,
+ config=ps_config,
+ start=True)
+
+ if has_chief:
+ server_lib.Server(
+ cs,
+ job_name='chief',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ if has_eval:
+ server_lib.Server(
+ cs,
+ job_name='evaluator',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
- gpu_mem_frac = 0.7 / num_workers
+ gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
   # Enable collective ops, which has no impact on non-collective ops.
   # TODO(yuefengz, tucker): remove this once we move the initialization of the
   # collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
+ if has_chief:
+ worker_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
+ else:
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
@@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps):
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
- return test_util.create_local_cluster(
+ return _create_cluster(
num_workers,
num_ps=num_ps,
+ has_chief=has_chief,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
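
For reference, a condensed sketch of what `_create_cluster` above does, assuming `portpicker` is available: pick free ports, build a `ClusterSpec`, and start one in-process gRPC server per task.

```python
import portpicker  # assumed available, as in the module above
from tensorflow.python.training import server_lib

ports = [portpicker.pick_unused_port() for _ in range(2)]
cluster = {'worker': ['localhost:%d' % p for p in ports]}
cs = server_lib.ClusterSpec(cluster)
# One in-process server per worker task; start=True brings each up now.
servers = [
    server_lib.Server(cs, job_name='worker', protocol='grpc',
                      task_index=i, start=True)
    for i in range(2)
]
target = 'grpc://' + cluster['worker'][0]  # session target, as in the tests
```
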
@@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
# We only cache the session in one test because another test may have a
@@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase):
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ if target is None:
+ target = self._default_target
if graph is None:
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
- graph=None, config=config, target=target or self._workers[0].target)
+ graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
else:
- with session.Session(
- graph=graph, config=config, target=target or
- self._workers[0].target) as sess:
+ with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
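
The cached-session change above keeps one session per thread so between-graph clients do not collide; a standalone sketch of the idiom:

```python
import contextlib
import threading

from tensorflow.python.client import session

_thread_local = threading.local()

@contextlib.contextmanager
def cached_session(target, config=None):
  # Lazily create one Session per thread and reuse it across calls.
  if getattr(_thread_local, 'sess', None) is None:
    _thread_local.sess = session.Session(
        graph=None, config=config, target=target)
  sess = _thread_local.sess
  with sess.graph.as_default(), sess.as_default():
    yield sess
```
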
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 96b6519bc4..361c8be590 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -94,11 +95,18 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster configurations.
task_type: the current task type.
task_id: the current task id.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
if cluster_spec:
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
self._cluster_spec = cluster_spec
# We typically don't need to do all-reduce in this strategy.
@@ -233,8 +241,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
" for variable: " + kwargs["name"])
def var_creator(*args, **kwargs):
+ # Record what collections this variable should be added to.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Create and wrap the variable.
v = next_creator(*args, **kwargs)
- return values.AggregatingVariable(v, aggregation)
+ wrapped = values.AggregatingVariable(v, aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the contained
+ # variable to the TRAINABLE_VARIABLES collection, so we manually
+ # remove it and replace with the wrapper. We can't set "trainable"
+ # to False for next_creator() since that causes functions like
+ # implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l.remove(v)
+ g.add_to_collections(collections, wrapped)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
+
+ return wrapped
else:
var_creator = next_creator
@@ -345,6 +380,10 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster configurations.
task_type: the current task type.
task_id: the current task id.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
del session_config
@@ -353,6 +392,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if not self._cluster_spec and cluster_spec:
self._cluster_spec = multi_worker_util.normalize_cluster_spec(
cluster_spec)
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec,
task_type, task_id)
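
A hypothetical construction (addresses are illustrative) showing the contract the two new `ValueError`s enforce: a `cluster_spec` must arrive together with `task_type` and `task_id` so the strategy knows which task it is running as.

```python
from tensorflow.contrib.distribute.python import parameter_server_strategy

cluster_spec = {
    'worker': ['localhost:2222', 'localhost:2223'],
    'ps': ['localhost:2224'],
}
# OK: cluster_spec plus the identity of this task.
strategy = parameter_server_strategy.ParameterServerStrategy(
    num_gpus_per_worker=1, cluster_spec=cluster_spec,
    task_type='worker', task_id=0)
# Raises ValueError: cluster_spec given without task_type/task_id.
parameter_server_strategy.ParameterServerStrategy(
    num_gpus_per_worker=1, cluster_spec=cluster_spec)
```
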
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index adfe3e8b02..0e2bfcec5f 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -24,6 +24,8 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
@@ -37,21 +39,15 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import training_util
+CHIEF = run_config.TaskType.CHIEF
+WORKER = run_config.TaskType.WORKER
+PS = run_config.TaskType.PS
-class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- parameterized.TestCase):
- @classmethod
- def setUpClass(cls):
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=2)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ],
- run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
- }
+class ParameterServerStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
def setUp(self):
self._result = 0
@@ -60,7 +56,7 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
- super(ParameterServerStrategyTest, self).setUp()
+ super(ParameterServerStrategyTestBase, self).setUp()
def _get_test_objects(self, task_type, task_id, num_gpus):
distribution = parameter_server_strategy.ParameterServerStrategy(
@@ -70,13 +66,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
distribution.configure(
cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
- return distribution, self._workers[task_id].target
+ return distribution, 'grpc://' + self._cluster_spec[WORKER][task_id]
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
d, _ = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -172,18 +168,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testDeviceAssignmentDistributed(self, num_gpus):
- self._test_device_assignment_distributed('worker', 1, num_gpus)
-
def _test_device_assignment_local(self,
d,
compute_device='CPU',
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target) as sess, \
d.scope():
def model_fn():
@@ -276,29 +267,12 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- def testDeviceAssignmentLocalCPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=0)
- self._test_device_assignment_local(
- distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
-
- def testDeviceAssignmentLocalOneGPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=1)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
-
- def testDeviceAssignmentLocalTwoGPUs(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=2)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
-
def _test_simple_increment(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
- num_workers = len(d._cluster_spec.as_dict().get('worker',
- ['dummy_worker']))
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if 'chief' in d._cluster_spec.as_dict():
+ num_workers += 1
else:
num_workers = 1
with ops.Graph().as_default(), \
@@ -357,6 +331,11 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ assert hasattr(d, '_cluster_spec') and d._cluster_spec
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if CHIEF in d._cluster_spec.as_dict():
+ num_workers += 1
+
with ops.Graph().as_default(), \
self.test_session(target=master_target) as sess, \
d.scope():
@@ -405,13 +384,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_id == 0:
+ if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id):
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
self._init_condition.acquire()
self._init_reached += 1
- while self._init_reached != 3:
+ while self._init_reached != num_workers:
self._init_condition.wait()
self._init_condition.notify_all()
self._init_condition.release()
@@ -428,9 +407,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertLess(error_after, error_before)
return error_after < error_before
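
The chief/worker rendezvous above (`_init_condition`, `_finish_condition`) is a plain condition-variable barrier sized by `num_workers`; a minimal standalone sketch:

```python
import threading

class Barrier(object):
  """Every client calls wait(); all unblock once n clients have arrived."""

  def __init__(self, n):
    self._n = n
    self._reached = 0
    self._cond = threading.Condition()

  def wait(self):
    with self._cond:
      self._reached += 1
      while self._reached < self._n:
        self._cond.wait()
      self._cond.notify_all()
```
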
+
+class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2)
+ cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]
+
+ def testDeviceAssignmentLocalCPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=0)
+ self._test_device_assignment_local(
+ distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+
+ def testDeviceAssignmentLocalOneGPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=1)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+
+ def testDeviceAssignmentLocalTwoGPUs(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testDeviceAssignmentDistributed(self, num_gpus):
+ self._test_device_assignment_distributed('worker', 1, num_gpus)
+
def testSimpleBetweenGraph(self):
self._run_between_graph_clients(self._test_simple_increment,
- self._cluster_spec, 0)
+ self._cluster_spec, context.num_gpus())
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
@@ -444,5 +456,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._cluster_spec, num_gpus)
+class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2, has_chief=True)
+ cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]
+
+ def testSimpleBetweenGraph(self):
+ self._run_between_graph_clients(self._test_simple_increment,
+ self._cluster_spec, context.num_gpus())
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def testGlobalStepIsWrapped(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ with ops.Graph().as_default(), distribution.scope():
+ created_step = training_util.create_global_step()
+ get_step = training_util.get_global_step()
+ self.assertEqual(created_step, get_step,
+ msg=('created_step %s type %s vs. get_step %s type %s' %
+ (id(created_step), created_step.__class__.__name__,
+ id(get_step), get_step.__class__.__name__)))
+ self.assertIs(values.AggregatingVariable, type(created_step))
+ self.assertIs(values.AggregatingVariable, type(get_step))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a486003076..6202a0750a 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self, tpu_cluster_resolver, steps_per_run):
+ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
"""Initializes the TPUStrategy object.
Args:
@@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
+      num_cores: Number of cores to use on the TPU. If None, the number of
+        cores and the topology of the TPU system are auto-detected.
"""
# TODO(isaprykin): Generalize the defaults. They are currently tailored for
# the unit test.
@@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ self._num_cores_override = num_cores
- # TODO(priyag): This should not be hardcoded here.
- self._host = '/device:CPU:0'
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
+ # TODO(frankchn): This should not be hardcoded here for pod purposes.
+ self._host = self.tpu_host_cpu_device(0)
+
def distribute_dataset(self, dataset_fn):
# TODO(priyag): Perhaps distribute across cores here.
return self._call_dataset_fn(dataset_fn)
@@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
+ # TODO(sourabhbajaj): Add support for TPU pods
with ops.device(self._host):
for _ in range(self.num_towers):
# Use control dependencies to ensure a deterministic ordering.
@@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
@property
def num_towers(self):
- return self._tpu_metadata.num_of_cores_per_host
+ return self._num_cores_override or self._tpu_metadata.num_cores
+
+ def tpu_host_cpu_device(self, host_id):
+ if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+ return '/replica:0/task:0/device:CPU:0'
+ return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)
+
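
The new `tpu_host_cpu_device` helper resolves a host CPU device string from the master address; a standalone sketch of the same logic (the `tpu_worker` job name is the one hardcoded above):

```python
def tpu_host_cpu_device(master, host_id, job='tpu_worker'):
  # Local masters ('' or 'local') use a bare device string; remote masters
  # address the host CPU through its job and task.
  if master in ('', 'local'):
    return '/replica:0/task:0/device:CPU:0'
  return '/job:%s/task:%d/device:CPU:0' % (job, host_id)

assert tpu_host_cpu_device('', 0) == '/replica:0/task:0/device:CPU:0'
assert tpu_host_cpu_device('grpc://host:8470', 1) == (
    '/job:tpu_worker/task:1/device:CPU:0')
```
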
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a58bb3a849..e73d9c193e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate):
return self._index[device]
return list(self._index.values())[0]
+ def _as_graph_element(self):
+ obj = self.get()
+ # pylint: disable=protected-access
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ return conv_fn()
+ return obj
+
def _assign_on_device(device, variable, tensor):
with ops.device(device):
@@ -354,8 +362,19 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
- return distribution_strategy_context.get_distribution_strategy().update(
- self, f, *args, **kwargs)
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ updates = strategy.update(self, f, *args, **kwargs)
+ grouped = strategy.group(updates)
+ if isinstance(updates, DistributedValues) and updates.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(mirrored_var.assign*(...)) may only update one tower.
+ index = {}
+ for d in updates.devices:
+ with ops.device(d), ops.control_dependencies([grouped]):
+ index[d] = array_ops.identity(updates.get(d))
+ return Mirrored(index)
+ else:
+ return grouped
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -1180,6 +1199,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase):
def __repr__(self):
return repr(self._v)
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
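
The cross-tower `assign*` change above groups all per-device updates, then re-reads each device's value under a control dependency on the group, so running any single returned value still executes every tower's update. A TF 1.x sketch of the pattern:

```python
import tensorflow as tf  # assumes a TF 1.x runtime

def grouped_reads(updates_by_device):
  # updates_by_device: {device string -> that device's assign tensor}.
  grouped = tf.group(*updates_by_device.values())
  index = {}
  for device, update in updates_by_device.items():
    with tf.device(device), tf.control_dependencies([grouped]):
      # Reading through identity after the group forces all updates to run.
      index[device] = tf.identity(update)
  return index
```
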
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
index 49a9afe3f6..31ee36f024 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class MatrixInverseTriLBijectorTest(test.TestCase):
"""Tests the correctness of the Y = inv(tril) transformation."""
@@ -40,7 +41,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0
return y
- @test_util.run_in_graph_and_eager_modes
def testComputesCorrectValues(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
self.assertEqual("matrix_inverse_tril", inv.name)
@@ -62,7 +62,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testOneByOneMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[5.]], dtype=np.float32)
@@ -81,7 +80,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testZeroByZeroMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.eye(0, dtype=np.float32)
@@ -100,7 +98,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testBatch(self):
# Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
# (2, 1).
@@ -125,20 +122,18 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testErrorOnInputRankTooLow(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([0.1], dtype=np.float32)
rank_error_msg = "must have rank at least 2"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
# TODO(b/80481923): Figure out why these assertions fail, and fix them.
## def testErrorOnInputNonSquare(self):
@@ -146,55 +141,50 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
## x_ = np.array([[1., 2., 3.],
## [4., 5., 6.]], dtype=np.float32)
## square_error_msg = "must be a square matrix"
- ## with self.test_session():
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputNotLowerTriangular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 2.],
[3., 4.]], dtype=np.float32)
triangular_error_msg = "must be lower triangular"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputSingular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 0.],
[0., 0.]], dtype=np.float32)
nonsingular_error_msg = "must have all diagonal entries nonzero"
- with self.test_session():
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
if __name__ == "__main__":
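
The refactor above trades `.eval()` under `test_session()` for `self.evaluate(...)` and hoists the mode decorator to the class; a minimal sketch of the resulting pattern:

```python
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test

@test_util.run_all_in_graph_and_eager_modes
class ExampleTest(test.TestCase):

  def testAdd(self):
    # self.evaluate() runs the tensor in a session under graph mode and
    # simply unwraps the value under eager mode.
    total = tf.constant(1.0) + tf.constant(2.0)
    self.assertAllClose(3.0, self.evaluate(total))

if __name__ == "__main__":
  test.main()
```
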
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
index a188843952..9a88f8f1bc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
@@ -38,23 +38,22 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorVector(self):
- with self.cached_session():
- ordered = Ordered()
- self.assertEqual("ordered", ordered.name)
- x = np.asarray([[2., 3, 4], [4., 8, 13]])
- y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
- self.assertAllClose(y, self.evaluate(ordered.forward(x)))
- self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
- self.assertAllClose(
- np.sum(np.asarray(y)[..., 1:], axis=-1),
- self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
- self.assertAllClose(
- self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
+ ordered = Ordered()
+ self.assertEqual("ordered", ordered.name)
+ x = np.asarray([[2., 3, 4], [4., 8, 13]])
+ y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
+ self.assertAllClose(y, self.evaluate(ordered.forward(x)))
+ self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
+ self.assertAllClose(
+ np.sum(np.asarray(y)[..., 1:], axis=-1),
+ self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
+ self.assertAllClose(
+ self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
def testBijectorUnknownShape(self):
with self.cached_session():
@@ -84,18 +83,17 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testShapeGetters(self):
- with self.cached_session():
- x = tensor_shape.TensorShape([4])
- y = tensor_shape.TensorShape([4])
- bijector = Ordered(validate_args=True)
- self.assertAllEqual(y, bijector.forward_event_shape(x))
- self.assertAllEqual(y.as_list(),
- self.evaluate(bijector.forward_event_shape_tensor(
- x.as_list())))
- self.assertAllEqual(x, bijector.inverse_event_shape(y))
- self.assertAllEqual(x.as_list(),
- self.evaluate(bijector.inverse_event_shape_tensor(
- y.as_list())))
+ x = tensor_shape.TensorShape([4])
+ y = tensor_shape.TensorShape([4])
+ bijector = Ordered(validate_args=True)
+ self.assertAllEqual(y, bijector.forward_event_shape(x))
+ self.assertAllEqual(y.as_list(),
+ self.evaluate(bijector.forward_event_shape_tensor(
+ x.as_list())))
+ self.assertAllEqual(x, bijector.inverse_event_shape(y))
+ self.assertAllEqual(x.as_list(),
+ self.evaluate(bijector.inverse_event_shape_tensor(
+ y.as_list())))
def testBijectiveAndFinite(self):
with self.cached_session():
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index d0098c3c10..8dad80aa64 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -43,16 +43,15 @@ class SoftsignBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorBounds(self):
bijector = Softsign(validate_args=True)
- with self.test_session():
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse(-3.).eval()
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval()
-
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse(3.).eval()
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse_log_det_jacobian(3., event_ndims=0).eval()
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse(-3.))
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(-3., event_ndims=0))
+
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse(3.))
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(3., event_ndims=0))
@test_util.run_in_graph_and_eager_modes
def testBijectorForwardInverse(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index f7b2efa7bc..05f5d30666 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -542,9 +542,9 @@ class PadDynamicTest(_PadTest, test.TestCase):
return False
+@test_util.run_all_in_graph_and_eager_modes
class TestMoveDimension(test.TestCase):
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_static_shape(self):
x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
@@ -561,7 +561,6 @@ class TestMoveDimension(test.TestCase):
x_perm = distribution_util.move_dimension(x, 4, 2)
self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_dynamic_shape(self):
x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index fa3f1bb7ad..84517b57c7 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":remote",
":saver",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -223,11 +224,24 @@ py_test(
],
)
+py_library(
+ name = "remote",
+ srcs = ["remote.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
py_test(
name = "remote_test",
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index a28bc8a43d..3f70f573b1 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -272,8 +272,8 @@ class ResNet50(tf.keras.Model):
else:
self.global_pooling = None
- def call(self, input_tensor, training):
- x = self.conv1(input_tensor)
+ def call(self, inputs, training=True):
+ x = self.conv1(inputs)
x = self.bn_conv1(x, training=training)
x = tf.nn.relu(x)
x = self.max_pool(x)
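
The signature change above (`input_tensor` to `inputs`, plus a default for `training`) matches the `tf.keras.Model.call` convention; a small sketch:

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(4)
    self.bn = tf.keras.layers.BatchNormalization()

  def call(self, inputs, training=True):
    # `inputs` is the conventional first argument name; defaulting
    # `training` lets callers invoke the model without extra arguments.
    return self.bn(self.dense(inputs), training=training)
```
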
diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py
new file mode 100644
index 0000000000..b74cf394f6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/remote.py
@@ -0,0 +1,73 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers to connect to remote servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
+from tensorflow.python.eager import context
+
+
+def connect_to_remote_host(remote_host=None, job_name="worker"):
+ """Connects to a single machine to enable remote execution on it.
+
+  Makes devices on the remote host available for use. Note that calling this
+  more than once works, but invalidates any tensor handles on the old remote
+  devices.
+
+ Using the default job_name of worker, you can schedule ops to run remotely as
+ follows:
+ ```python
+ # Enable eager execution, and connect to the remote host.
+ tf.enable_eager_execution()
+ tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ # The following tensors should be resident on the remote device, and the op
+ # will also execute remotely.
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ ```
+
+ Args:
+    remote_host: The address of the remote server, in host:port format.
+ job_name: The job name under which the new server will be accessible.
+
+ Raises:
+ ValueError: if remote_host is None.
+ """
+ if remote_host is None:
+    raise ValueError("Must provide a remote_host")
+ cluster_def = ClusterDef()
+ job_def = cluster_def.job.add()
+ job_def.name = job_name
+ job_def.tasks[0] = "127.0.0.1:0"
+ job_def.tasks[1] = remote_host
+
+ server_def = ServerDef(
+ cluster=cluster_def,
+ job_name=job_name,
+ task_index=0,
+ protocol="grpc")
+
+ # TODO(nareshmodi): Make this default since it works in more situations.
+ os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+ context.set_server_def(server_def)
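
In the `ServerDef` built above, task 0 is the local client (`127.0.0.1:0`) and task 1 is the remote host, which is why remote placement targets `/task:1`. A hypothetical session (the address is illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.eager.python import remote

tf.enable_eager_execution()
remote.connect_to_remote_host("worker.example.com:8470")

# Task 1 is the remote host; ops under this scope execute there.
with tf.device("/job:worker/replica:0/task:1/device:CPU:0"):
  y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
print(y.numpy())  # [[2. 2.], [2. 2.]]
```
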
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 76f48eeb1c..13029db975 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import backprop
@@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase):
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
+ def setUp(self):
# Start the local server.
context.set_server_def(
server_def=get_server_def(
@@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ @run_sync_and_async
+ def testConnectToRemoteServer(self):
+ """Basic server connection."""
+ remote.connect_to_remote_host(self._cached_server1_target)
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
if __name__ == "__main__":
ops.enable_eager_execution()
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 4dfd083443..fe7f1b72fc 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`.
@@TensorSpec
+@@connect_to_remote_host
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.network import Sequential
from tensorflow.contrib.eager.python.network import save_network_checkpoint
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
+from tensorflow.contrib.eager.python.remote import connect_to_remote_host
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 45a0ded7eb..458a50f25c 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -293,6 +293,7 @@ def generated_test_models():
"topk",
"transpose",
#"transpose_conv", # disabled due to b/111213074
+ "unpack",
"where",
]
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 70178b2faa..e81f9e4f51 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -286,6 +286,11 @@ typedef struct {
int axis;
} TfLiteOneHotParams;
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 32b1cfd2d8..c39013bb42 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -2434,7 +2434,8 @@ class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+ DISABLED_LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -2541,7 +2542,8 @@ class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
+ DISABLED_LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -3200,7 +3202,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, DISABLED_LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index 776803da8c..f255017ad9 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite APIs
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
index d979353bb3..ee6150b60e 100644
--- a/tensorflow/contrib/lite/g3doc/custom_operators.md
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# How to use custom operators
diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md
index d79a2696b4..c38b928684 100644
--- a/tensorflow/contrib/lite/g3doc/demo_android.md
+++ b/tensorflow/contrib/lite/g3doc/demo_android.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Android Demo App
diff --git a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md
index a554898899..7579ad84a0 100644
--- a/tensorflow/contrib/lite/g3doc/demo_ios.md
+++ b/tensorflow/contrib/lite/g3doc/demo_ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# iOS Demo App
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index dc9cc98c08..90e7915c52 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Developer Guide
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index d78d373ccf..5ff0412209 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite for iOS
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 4ceb9a53dc..b984671e89 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# List of Hosted Models
diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md
index b06f4fd3b8..0d571ce547 100644
--- a/tensorflow/contrib/lite/g3doc/ops_versioning.md
+++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite Ops Versioning
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index be60d7941a..8cf43496df 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Introduction to TensorFlow Lite
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 5cd0aab44f..28cb6aba6e 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Performance
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md
index 9fcf79ba00..8ed8640582 100644
--- a/tensorflow/contrib/lite/g3doc/rpi.md
+++ b/tensorflow/contrib/lite/g3doc/rpi.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite for Raspberry Pi
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index aa65ec9988..fb9d5f6787 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite & TensorFlow Compatibility Guide
@@ -843,6 +841,19 @@ Outputs {
}
```
+**UNPACK**
+
+```
+Inputs {
+  0: a tensor to unpack.
+  1: an integer (num): how many tensors to unpack along the axis.
+  2: an integer (axis): the axis to unpack along.
+}
+Outputs {
+  0-N: the tensors unpacked from the input tensor.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
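
The new UNPACK entry mirrors TensorFlow's `tf.unstack`; a quick illustration of the `num`/`axis` semantics recorded in `TfLiteUnpackParams`:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
# Unpack along axis 0 into num=2 tensors of shape [3].
rows = tf.unstack(x, num=2, axis=0)
# Unpack along axis 1 into num=3 tensors of shape [2].
cols = tf.unstack(x, num=3, axis=1)
```
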
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index 76e16fc9db..c7cdee07de 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on Android
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index bd047bfcec..d003bb2f38 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Overview
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index 6223707892..be8b4100c8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on iOS
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4c2071ed05..4d4bb3bc08 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Integrating TensorFlow libraries
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index a0192c3541..7436594fd8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Optimizing for mobile
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index 6b4e4a92bd..d1c67d4c61 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Preparing models for mobile deployment
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 1f528fdab9..407d52f0e8 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -211,6 +211,7 @@ cc_library(
"transpose_conv.cc",
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
+ "unpack.cc",
],
hdrs = [
"padding.h",
@@ -1201,6 +1202,20 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "unpack_test",
+ size = "small",
+ srcs = ["unpack_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 40160289c8..7319636bf5 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -2143,38 +2143,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
-template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("DepthToSpace");
-
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
-
- const int output_depth = ArraySize(output_dims, 0);
- const int batch_size = ArraySize(output_dims, 3);
-
- // Number of continuous values that we can copy in one interation.
- const int stride = block_size * output_depth;
-
- for (int batch = 0; batch < batch_size; ++batch) {
- for (int in_h = 0; in_h < input_height; ++in_h) {
- const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
- const T* src = input_ptr;
- for (int in_w = 0; in_w < input_width; ++in_w) {
- memcpy(output_data, src, stride * sizeof(T));
- output_data += stride;
- src += input_depth;
- }
- input_ptr += stride;
- }
- }
- }
-}
-
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
@@ -2250,25 +2218,87 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("DepthToSpace");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+
+ const int output_depth = output_shape.Dims(3);
+ const int batch_size = output_shape.Dims(0);
+
+  // Number of contiguous values that we can copy in one iteration.
+ const int stride = op_params.block_size * output_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
+ const T* src = input_ptr;
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ memcpy(output_data, src, stride * sizeof(T));
+ output_data += stride;
+ src += input_depth;
+ }
+ input_ptr += stride;
+ }
+ }
+ }
+}
+
+// Legacy Dims<4>.
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
int block_size, T* output_data,
const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
- const int input_depth = ArraySize(input_dims, 0);
- const int batch_size = ArraySize(input_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+
+ const int input_depth = input_shape.Dims(3);
+ const int batch_size = input_shape.Dims(0);
// Number of contiguous values that we can copy in one iteration.
- const int stride = block_size * input_depth;
+ const int stride = op_params.block_size * input_depth;
for (int batch = 0; batch < batch_size; ++batch) {
for (int out_h = 0; out_h < output_height; ++out_h) {
- T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
T* dst = output_ptr;
for (int out_w = 0; out_w < output_width; ++out_w) {
memcpy(dst, input_data, stride * sizeof(T));
@@ -2281,6 +2311,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Dims<4>.
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <FusedActivationFunctionType Ac>
void NonGlobalBatchNormalization(
const float* input_data, const Dims<4>& input_dims, const float* mean_data,
@@ -5565,20 +5607,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
}
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -5613,14 +5664,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
spatial_offset % block_shape_width - crops_left;
TFLITE_DCHECK_GE(out_w, 0);
TFLITE_DCHECK_LT(out_w, output_width);
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
}
}
+// Legacy Dims<4>.
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
template <typename T>
void TypedMemset(void* ptr, T value, size_t num) {
// Optimization for common cases where memset() will suffice.
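
The mechanical change running through this file is the Offset convention: the legacy Dims<4> form was called as Offset(dims, c, w, h, b) over reversed storage, while the RuntimeShape form is Offset(shape, b, h, w, c) in NHWC order. A minimal sketch of the row-major index math the new calls assume, with Shape4 as a stand-in for the real RuntimeShape:

    #include <array>
    #include <cassert>
    #include <cstdio>

    // Minimal stand-in for RuntimeShape (assumption: row-major NHWC storage).
    struct Shape4 {
      std::array<int, 4> d;  // {batch, height, width, channels}
      int Dims(int i) const { return d[i]; }
    };

    // Flat index matching the new Offset(shape, b, h, w, c) call order.
    int Offset(const Shape4& s, int b, int h, int w, int c) {
      return ((b * s.Dims(1) + h) * s.Dims(2) + w) * s.Dims(3) + c;
    }

    int main() {
      const Shape4 s{{2, 3, 4, 5}};  // N=2, H=3, W=4, C=5
      assert(Offset(s, 0, 2, 3, 4) == 3 * 4 * 5 - 1);  // last element of batch 0
      std::printf("%d\n", Offset(s, 1, 0, 0, 0));      // 60: start of batch 1
      return 0;
    }
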
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index a6aef4fa29..3492a6c2f9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -407,18 +407,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width * block_size, output_width);
TFLITE_DCHECK_EQ(input_height * block_size, output_height);
@@ -437,9 +448,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
const int in_h = out_h / block_size;
const int in_b = out_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -448,19 +459,42 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Dims<4>.
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
int block_size, T* output_data,
const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
@@ -478,9 +512,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
const int out_h = in_h / block_size;
const int out_b = in_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -489,6 +523,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Dims<4>.
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
const float* weights_data,
const Dims<4>& weights_dims, const float* bias_data,
@@ -2034,6 +2080,25 @@ void Pack(int dim, const Scalar* const* input_data,
}
}
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ int outer_size = 1;
+ for (int i = dimensions - axis; i < 4; i++) {
+ outer_size *= input_dims.sizes[i];
+ }
+
+ const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ Scalar* output_ptr = output_datas[i] + copy_size * k;
+ int loc = k * outputs_count * copy_size + i * copy_size;
+ memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
+ }
+ }
+}
+
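
The Unpack loop above views the input as outer_size groups of outputs_count x copy_size elements and deals one copy_size slice per output. A self-contained sketch of that arithmetic for a [3, 2] input unpacked along axis 0 (plain arrays stand in for the TF Lite tensors):

    #include <cstdio>
    #include <cstring>
    #include <vector>

    int main() {
      const int outputs_count = 3;  // input_shape[axis]
      const int copy_size = 2;      // product of dims after axis
      const int outer_size = 1;     // product of dims before axis
      const float input[6] = {1, 2, 3, 4, 5, 6};
      std::vector<std::vector<float>> outputs(
          outputs_count, std::vector<float>(outer_size * copy_size));
      for (int k = 0; k < outer_size; ++k) {
        for (int i = 0; i < outputs_count; ++i) {
          const int loc = (k * outputs_count + i) * copy_size;
          std::memcpy(outputs[i].data() + k * copy_size, input + loc,
                      copy_size * sizeof(float));
        }
      }
      for (const auto& o : outputs) std::printf("{%g, %g} ", o[0], o[1]);
      std::printf("\n");  // prints: {1, 2} {3, 4} {5, 6}
      return 0;
    }
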
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatenation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
@@ -3467,45 +3532,56 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims,
- const int32_t pad_value) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+inline void SpaceToBatchND(
+ const SpaceToBatchParams& params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* paddings_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
const int block_shape_height = block_shape_data[0];
const int block_shape_width = block_shape_data[1];
const int padding_top = paddings_data[0];
const int padding_left = paddings_data[2];
+ // For uint8 quantized, the correct padding "zero value" is the output offset.
+ const int32_t pad_value = params.output_offset;
+
for (int out_b = 0; out_b < output_batch_size; ++out_b) {
int input_batch = out_b % input_batch_size;
int shift_w = (out_b / input_batch_size) % block_shape_width;
int shift_h = (out_b / input_batch_size) / block_shape_width;
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
if (out_h * block_shape_height + shift_h < padding_top ||
out_h * block_shape_height + shift_h >=
padding_top + input_height ||
out_w * block_shape_width + shift_w < padding_left ||
out_w * block_shape_width + shift_w >= padding_left + input_width) {
+ // This may not execute correctly when pad_value != 0 and T != uint8.
memset(out, pad_value, depth * sizeof(T));
} else {
const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
+ input1_data +
+ Offset(input1_shape, input_batch,
(out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
+ (out_w * block_shape_width + shift_w) - padding_left, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -3513,30 +3589,63 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Dims<4>.
template <typename T>
inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
const int32* block_shape_data,
const Dims<4>& block_shape_dims,
const int32* paddings_data,
const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims,
- paddings_data, paddings_dims, output_data, output_dims, 0);
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = pad_value;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
}
+// Legacy variant, kept for callers that omit pad_value (implying pad_value = 0).
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
const int32* block_shape_data,
const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = 0;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
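
The SpaceToBatchND loop above decodes each output batch index into an input batch plus a spatial shift inside the block. A standalone sketch of that decomposition, assuming a batch size of 2 and a 2x2 block shape:

    #include <cstdio>

    int main() {
      const int input_batch_size = 2, block_h = 2, block_w = 2;
      for (int out_b = 0; out_b < input_batch_size * block_h * block_w; ++out_b) {
        const int in_b = out_b % input_batch_size;
        const int shift_w = (out_b / input_batch_size) % block_w;
        const int shift_h = (out_b / input_batch_size) / block_w;
        std::printf("out_b=%d -> in_b=%d shift_h=%d shift_w=%d\n",
                    out_b, in_b, shift_h, shift_w);
      }
      return 0;
    }
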
+template <typename T>
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -3558,14 +3667,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
if (out_w < 0 || out_w >= output_width) {
continue;
}
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
}
}
+// Legacy Dims<4>.
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
// scalar input that provides the padding value. Therefore pad_value_ptr can be
// equivalent to a simple input1_data. For Pad, it should point to a zero
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 27b78aa225..2603ed2eb7 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -745,7 +745,7 @@ struct ConvParams {
};
struct DepthToSpaceParams {
- int16 block_size;
+ int32 block_size;
};
struct DepthwiseParams {
@@ -871,8 +871,13 @@ struct SoftmaxParams {
int diff_min;
};
+struct SpaceToBatchParams {
+ // "Zero" padding for uint8 means padding with the output offset.
+ int32 output_offset;
+};
+
struct SpaceToDepthParams {
- int16 block_size;
+ int32 block_size;
};
struct SplitParams {
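
The new SpaceToBatchParams::output_offset exists because "zero" padding of a uint8 quantized tensor must write the zero point, not the literal byte 0. A numeric sketch with hypothetical quantization parameters:

    #include <cstdint>
    #include <cstdio>

    // With real = scale * (q - zero_point), the real value 0.0 is stored as
    // q = zero_point, so padding must write the offset rather than 0.
    int main() {
      const float scale = 0.5f;
      const int32_t zero_point = 128;  // hypothetical
      const uint8_t pad = static_cast<uint8_t>(zero_point);
      std::printf("padded real value = %g\n", scale * (pad - zero_point));  // 0
      std::printf("literal-0 padding = %g\n", scale * (0 - zero_point));    // -64
      return 0;
    }
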
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index ba251c451e..74dc3f25f9 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -37,7 +37,7 @@ namespace builtin {
namespace lstm {
struct OpData {
- // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // Which kernel type to use. Full kernel (20 inputs) or basic kernel
// (5 inputs).
TfLiteLSTMKernelType kernel_type;
@@ -47,7 +47,7 @@ struct OpData {
int scratch_tensor_index;
};
-// For full inputs kernel (18 or 20 inputs).
+// For the full-inputs kernel (20 inputs).
namespace full {
// Input Tensors of size {n_batch, n_input}
@@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
-// If the node has 20 inputs, the following 2 tensors are used as state tensors.
-// These are defined as variable tensors, and will be modified by this op.
+// These state tensors are defined as variable tensors, and will be modified by
+// this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;
// Output tensors.
-// * If the node has 18 inputs, these 2 tensors are used as state tensors.
-// * If the node has 20 inputs, these 2 tensors are ignored.
-// TODO(ycling): Make the 2 output state tensors optional, and propagate the
-// state to output tensors when the 2 tensors present.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
-
- // True if the node is using input variable state tensors. It means:
- // * The state tensors are defined as inputs. In this case it would be the
- // 19th and 20th input tensors.
- // * Otherwise, the output tensors are used to store states.
- bool use_input_variable_states;
- if (node->inputs->size == 20) {
- use_input_variable_states = true;
- op_data->activation_state_tensor_index =
- node->inputs->data[kInputActivationStateTensor];
- op_data->cell_state_tensor_index =
- node->inputs->data[kInputCellStateTensor];
- } else if (node->inputs->size == 18) {
- use_input_variable_states = false;
- op_data->activation_state_tensor_index =
- node->outputs->data[kOutputStateTensor];
- op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
- } else {
- context->ReportError(
- context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* cell_state =
&context->tensors[op_data->cell_state_tensor_index];
- if (use_input_variable_states) {
- // Check the shape of input state tensors.
- // These tensor may be 1D or 2D. It's fine as long as the total size is
- // correct.
- TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
- n_batch * n_output);
- TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
- } else {
- // If the state tensors are outputs, this function takes the
- // responsibility to resize the state tensors.
- TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
- activation_state_size->data[0] = n_batch;
- activation_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
- activation_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
- // Mark state tensors as persistent tensors.
- activation_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
- }
+ // Check the shape of input state tensors.
+  // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
// Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
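
The lstm.cc hunk above hard-codes the 20-input form: exactly one output, with the variable state tensors always taken from inputs 18 and 19. A minimal sketch of the tightened contract, using illustrative names rather than the real TfLiteNode:

    #include <cassert>
    #include <vector>

    struct NodeSketch {
      std::vector<int> inputs, outputs;
    };

    bool ValidFullLstm(const NodeSketch& node) {
      constexpr int kInputActivationStateTensor = 18;
      constexpr int kInputCellStateTensor = 19;
      if (node.inputs.size() != 20 || node.outputs.size() != 1) return false;
      // State tensor indices now always come from the input list.
      return node.inputs[kInputActivationStateTensor] >= 0 &&
             node.inputs[kInputCellStateTensor] >= 0;
    }

    int main() {
      NodeSketch node;
      node.inputs.assign(20, /*arbitrary tensor index=*/1);
      node.outputs = {0};
      assert(ValidFullLstm(node));
      return 0;
    }
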
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0266f5fe57..e7ddfceb45 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
+
BuildInterpreter(input_shapes);
}
@@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest {
void SetUp() override {
input_to_input_weights_ = {
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
@@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index 1c728a4733..90a915bb02 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -101,8 +101,6 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -180,22 +178,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -238,8 +220,6 @@ class LSTMOpModel : public SingleOpModel {
int input_cell_state_;
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -324,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
lstm.SetCellToOutputWeights(
{-0.17135078, 0.82760304, 0.85573703, -0.77109635});
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
// Verify the model by unpacking it.
lstm.Verify();
}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 10d1fcc5bc..341fd14127 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -113,6 +113,7 @@ TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_LOGICAL_OR();
TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
+TfLiteRegistration* Register_UNPACK();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -235,6 +236,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
+ AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
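
With UNPACK registered, it resolves like any other builtin when an interpreter is built against this tree's headers. A sketch of typical construction; treat the exact signatures as assumptions drawn from standard TF Lite usage of the period:

    #include <memory>

    #include "tensorflow/contrib/lite/interpreter.h"
    #include "tensorflow/contrib/lite/kernels/register.h"
    #include "tensorflow/contrib/lite/model.h"

    std::unique_ptr<tflite::Interpreter> BuildInterpreter(
        const tflite::FlatBufferModel& model) {
      tflite::ops::builtin::BuiltinOpResolver resolver;  // now includes UNPACK
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(model, resolver)(&interpreter);
      return interpreter;
    }
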
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
new file mode 100644
index 0000000000..4998f88b41
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unpack {
+namespace {
+
+constexpr int kInputTensor = 0;
+
+// Op data for unpack op.
+struct OpData {
+ int num;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
+ TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) > data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+
+ const TfLiteIntArray* input_shape = input->dims;
+  // `num` must equal shape[axis].
+  // Resize outputs: each output has rank R - 1.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
+ int o = 0;
+ for (int index = 0; index < NumDimensions(input); ++index) {
+ if (index != data->axis) {
+ output_shape->data[o++] = input_shape->data[index];
+ }
+ }
+
+ TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]);
+ for (int i = 0; i < data->num; ++i) {
+ TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->type, input->type);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output, copied_output_shape));
+ }
+
+ TfLiteIntArrayFree(output_shape);
+ return kTfLiteOk;
+}
+
+template <typename T>
+void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input, int output_count, int axis) {
+ VectorOfTensors<T> all_outputs(*context, *node->outputs);
+ reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
+ NumDimensions(input), output_count,
+ all_outputs.data(), **all_outputs.dims());
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ UnpackImpl<float>(context, node, input, data->num, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace unpack
+
+TfLiteRegistration* Register_UNPACK() {
+ static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare,
+ unpack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
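
For a reader skimming the new kernel, the shape rule in Prepare() is the essential part: each of the `num` outputs has rank R - 1, obtained by dropping dimension `axis`. A standalone sketch (plain vectors in place of TfLiteIntArray):

    #include <cstdio>
    #include <vector>

    std::vector<int> UnpackOutputShape(const std::vector<int>& input_shape,
                                       int axis) {
      std::vector<int> out;
      for (int i = 0; i < static_cast<int>(input_shape.size()); ++i) {
        if (i != axis) out.push_back(input_shape[i]);
      }
      return out;
    }

    int main() {
      const auto shape = UnpackOutputShape({2, 2, 2}, /*axis=*/2);
      std::printf("[%d, %d], num = 2 outputs\n", shape[0], shape[1]);  // [2, 2]
      return 0;
    }
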
diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/contrib/lite/kernels/unpack_test.cc
new file mode 100644
index 0000000000..4efc92a0fd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack_test.cc
@@ -0,0 +1,225 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class UnpackOpModel : public SingleOpModel {
+ public:
+ UnpackOpModel(const TensorData& input, int axis) {
+    CHECK_LT(axis, static_cast<int>(input.shape.size()));
+ const int num_outputs = input.shape[axis];
+ input_ = AddInput(input);
+ for (int i = 0; i < num_outputs; ++i) {
+ outputs_.push_back(AddOutput(input.type));
+ }
+ SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions,
+                 CreateUnpackOptions(builder_, num_outputs, axis).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+
+ std::vector<std::vector<T>> GetOutputDatas() {
+ std::vector<std::vector<T>> output_datas;
+ for (const int output : outputs_) {
+ std::cerr << "the output is " << output << std::endl;
+ output_datas.push_back(ExtractVector<T>(output));
+ }
+ return output_datas;
+ }
+
+ std::vector<std::vector<int>> GetOutputShapes() {
+ std::vector<std::vector<int>> output_shapes;
+ for (const int output : outputs_) {
+ output_shapes.push_back(GetTensorShape(output));
+ }
+ return output_shapes;
+ }
+
+ private:
+ int input_;
+ std::vector<int> outputs_;
+};
+
+// float32 tests.
+TEST(UnpackOpTest, FloatThreeOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, FloatOneOutput) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+// int32 tests.
+TEST(UnpackOpTest, IntThreeOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, IntOneOutput) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
index b58ae26601..6195426d6d 100755
--- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
+++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
+# TODO(ycling): Refactor: move this script into `tools/make`.
set -e
echo "Starting"
@@ -32,7 +33,7 @@ echo "Headers, populating: TensorFlow Lite"
cd $TFLITE_DIR/../../..
find tensorflow/contrib/lite -name '*.h' \
- -not -path 'tensorflow/contrib/lite/downloads/*' \
+ -not -path 'tensorflow/contrib/lite/tools/*' \
-not -path 'tensorflow/contrib/lite/examples/*' \
-not -path 'tensorflow/contrib/lite/gen/*' \
-not -path 'tensorflow/contrib/lite/toco/*' \
@@ -44,7 +45,7 @@ tar xf tmp.tar
rm -f tmp.tar
echo "Headers, populating: Flatbuffer"
-cd $TFLITE_DIR/downloads/flatbuffers/include/
+cd $TFLITE_DIR/tools/make/downloads/flatbuffers/include/
find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T -
cd $FW_DIR_TFLITE_HDRS
tar xf tmp.tar
@@ -57,7 +58,7 @@ cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tens
$FW_DIR_TFLITE
echo "Copying static libraries"
-cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \
+cp $TFLITE_DIR/tools/make/gen/lib/libtensorflow-lite.a \
$FW_DIR_TFLITE/tensorflow_lite
# This is required, otherwise they interfere with the documentation of the
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 7ca12cb841..da3ed42e20 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -745,6 +745,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = static_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNPACK: {
+ TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+ params->num = unpack_params->num();
+ params->axis = unpack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:
@@ -790,7 +799,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGICAL_OR:
case BuiltinOperator_LOGICAL_AND:
case BuiltinOperator_LOGICAL_NOT:
- case BuiltinOperator_UNPACK:
case BuiltinOperator_FLOOR_DIV:
case BuiltinOperator_REDUCE_ANY:
break;
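
ParseOpData above follows the usual pattern for builtins: allocate a plain-old-data params blob, copy fields out of the flatbuffer options, and hand ownership to the runtime. A hedged sketch of that pattern; the struct name is illustrative, not the real TfLiteUnpackParams:

    #include <cstdlib>

    struct UnpackParamsSketch {
      int num;
      int axis;
    };

    // POD blobs are malloc'd so the runtime can later release them with free().
    template <typename T>
    T* MallocPOD() {
      return static_cast<T*>(std::malloc(sizeof(T)));
    }

    int main() {
      auto* params = MallocPOD<UnpackParamsSketch>();
      params->num = 3;   // copied from the flatbuffer UnpackOptions
      params->axis = 0;
      std::free(params);
      return 0;
    }
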
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index fad39bee9e..8ecf0b6154 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -126,7 +126,7 @@ TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
+TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
@@ -139,7 +139,7 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, AsrAmTest) {
+TEST_P(SpeechTest, DISABLED_AsrAmTest) {
std::stringstream os;
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
@@ -156,7 +156,7 @@ TEST_P(SpeechTest, AsrAmTest) {
// through the interpreter and stored the sum of all the outputs, which was then
// compared for correctness. In this test we are comparing all the intermediate
// results.
-TEST_P(SpeechTest, AsrLmTest) {
+TEST_P(SpeechTest, DISABLED_AsrLmTest) {
std::ifstream in_file;
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
@@ -165,7 +165,7 @@ TEST_P(SpeechTest, AsrLmTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, EndpointerTest) {
+TEST_P(SpeechTest, DISABLED_EndpointerTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
@@ -178,7 +178,7 @@ TEST_P(SpeechTest, EndpointerTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, TtsTest) {
+TEST_P(SpeechTest, DISABLED_TtsTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 599c82940e..a329bb3a25 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2378,7 +2378,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
- "split_tflite_lstm_inputs": [True, False],
+ "split_tflite_lstm_inputs": [False],
},
]
@@ -3149,6 +3149,36 @@ def make_pack_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_unpack_tests(zip_path):
+ """Make a set of tests to do unstack."""
+
+ test_parameters = [{
+ "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_valid_axis(parameters):
+ """Return a tweaked version of 'axis'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ while axis > len(shape) - 1:
+ axis -= 1
+ return axis
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
+ outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
+ return [input_tensor], outs
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def _make_logical_tests(op):
"""Make a set of tests to do logical operations."""
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4dacf9c84b..1836eb53b9 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() {
void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensorsToZero();
-
- // Below is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Remove the code below after nobody is using the 18-inputs
- // definition.
- for (auto node_index : interpreter_->execution_plan()) {
- const auto& node_and_reg = interpreter_->node_and_registration(node_index);
- const auto& node = node_and_reg->first;
- const auto& registration = node_and_reg->second;
-
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
- const auto* params =
- reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
- if (params->kernel_type == kTfLiteLSTMFullKernel &&
- node.inputs->size == 18 && node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
- }
- }
- }
- }
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index f489c5ac65..94602445c2 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1967,6 +1967,20 @@ void ConvertCTCBeamSearchDecoderOperator(
(*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
}
+void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
+ unpack_op->set_op(op_name);
+ unpack_op->set_name(src_op.outputs[0]);
+  CHECK_EQ(src_op.inputs.size(), 1);
+ *unpack_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*unpack_op->mutable_attr())["T"].set_type(data_type);
+ (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
+ (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2228,6 +2242,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertCTCBeamSearchDecoderOperator(
model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
"CTCBeamSearchDecoder", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kUnpack) {
+ ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
+ "Unpack", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c8310161cb..323eefcd3a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -227,6 +227,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
ArrayDataType::kFloat;
break;
}
+ case OperatorType::kUnpack: {
+ CHECK_EQ(op->inputs.size(), 1);
+ const int output_size = op->outputs.size();
+ for (int i = 0; i < output_size; ++i) {
+ model->GetArray(op->outputs[i]).data_type =
+ model->GetArray(op->inputs[0]).data_type;
+ }
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 91e290439a..fa2be961f5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1629,6 +1629,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
}
}
+void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ output_dims.reserve(input_dims.size() - 1);
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (i != op->axis) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ for (const string& output_name : op->outputs) {
+ auto& output_array = model->GetArray(output_name);
+ if (output_array.has_shape()) {
+ return;
+ }
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1880,6 +1906,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kOneHot:
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
break;
+ case OperatorType::kUnpack:
+ ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b7fffbce22..0e04ee4ccb 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1576,6 +1576,26 @@ tensorflow::Status ConvertPackOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertUnpackOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Unpack");
+ auto op = absl::make_unique<UnpackOperator>();
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ QCHECK_EQ(num_inputs, 1);
+ op->inputs.push_back(node.input(0));
+ op->num = GetIntAttr(node, "num");
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
+ op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
+
+ op->outputs.push_back(node.name()); // Implicit :0.
+ for (int i = 1; i < op->num; ++i) {
+ op->outputs.push_back(node.name() + ":" + std::to_string(i));
+ }
+ model->operators.emplace_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
// be able to fully support such graphs, including performing inference,
@@ -2020,6 +2040,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"TopK", ConvertTopKV2Operator},
{"TopKV2", ConvertTopKV2Operator},
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
+ {"Unpack", ConvertUnpackOperator},
});
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 412e14c4ad..3a909c3d8e 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -149,6 +149,7 @@ enum class OperatorType : uint8 {
kLogicalNot,
kLogicalOr,
kCTCBeamSearchDecoder,
+ kUnpack,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1828,6 +1829,20 @@ struct LogicalOrOperator : Operator {
LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
};
+// Unpack operator:
+//
+// Inputs:
+// Inputs[0]: required: The input tensor to unpack along 'axis' into 'num'
+// outputs.
+//
+// TensorFlow equivalent: tf.unstack.
+struct UnpackOperator : Operator {
+ UnpackOperator() : Operator(OperatorType::kUnpack) {}
+ int num;
+ int axis;
+ ArrayDataType dtype = ArrayDataType::kNone;
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index dcb5fff39f..e9383098cc 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1110,6 +1110,24 @@ class CTCBeamSearchDecoder
int GetVersion(const Operator& op) const override { return 1; }
};
+class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
+ ::tflite::BuiltinOptions_UnpackOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->num = options.num();
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1353,6 +1371,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
+ ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
+ OperatorType::kUnpack));
// Custom Operators.
ops.push_back(
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index fc854461b4..bb0b457483 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -476,6 +476,16 @@ TEST_F(OperatorTest, BuiltinOneHot) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, BuiltinUnpack) {
+ UnpackOperator op;
+ op.num = 5;
+ op.axis = 2;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op);
+ EXPECT_EQ(op.num, output_toco_op->num);
+ EXPECT_EQ(op.axis, output_toco_op->axis);
+}
+
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
CTCBeamSearchDecoderOperator op;
op.beam_width = 3;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 3a4542f522..6ab93d9316 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -405,6 +405,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
+ HANDLE_OPERATORTYPENAME_CASE(Unpack)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD
new file mode 100644
index 0000000000..01fbce0ac7
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/BUILD
@@ -0,0 +1,11 @@
+# TODO(suharshs): Write quantize_weights tests that use small exportable files.
+# Then we can remove this file.
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
new file mode 100644
index 0000000000..0758514e39
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -0,0 +1,280 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+namespace optimize {
+
+namespace {
+
+// The minimum number of elements a weights array must have to be quantized
+// by this transformation.
+// TODO(suharshs): Make this configurable.
+const int kWeightsMinSize = 1024;
+
+// Nudge min and max so that floating point 0 falls exactly on a quantized
+// value, returning the nudged scale and zero_point.
+//
+// Although this code originates from FakeQuantization in quantized training,
+// we may deviate from that implementation as we please since we do not
+// fine-tune the weights with quantized training.
+void GetQuantizationParams(const float min, const float max,
+ const int quant_min, const int quant_max,
+ QuantizationParametersT* quantization_params) {
+ // Adjust the boundaries to guarantee 0 is included.
+ const float quant_min_float = std::min(static_cast<float>(quant_min), 0.0f);
+ const float quant_max_float = std::max(static_cast<float>(quant_max), 0.0f);
+ const float scale = (max - min) / (quant_max_float - quant_min_float);
+ const float zero_point_from_min = quant_min_float - min / scale;
+ int64_t zero_point;
+ if (zero_point_from_min < quant_min_float) {
+ zero_point = static_cast<int64_t>(quant_min);
+ } else if (zero_point_from_min > quant_max_float) {
+ zero_point = static_cast<int64_t>(quant_max);
+ } else {
+ zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
+ }
+ quantization_params->scale = {scale};
+ quantization_params->zero_point = {zero_point};
+}
+
+// Returns the number of elements in tensor.
+uint64 NumElements(const TensorT* tensor) {
+ if (tensor->shape.empty()) {
+ LOG(FATAL) << "Tensor has no shape information.";
+ }
+ uint64 num_elements = 1;
+ for (const uint64 dim : tensor->shape) {
+ num_elements *= dim;
+ }
+ return num_elements;
+}
+
+uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
+ int32_t tensor_idx) {
+ uint64 count = 0;
+ for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
+ const OperatorT* op = subgraph->operators[op_idx].get();
+ if (op == nullptr) {
+ continue;
+ }
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ if (op->inputs[i] == tensor_idx) {
+ count++;
+ }
+ }
+ }
+ return count;
+}
+
+// Returns true if the Operator's weight tensor should be quantized.
+bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op,
+ TensorT** tensor, int32_t* tensor_idx,
+ int32_t* op_input_index) {
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+ const BuiltinOperator op_code =
+ model->operator_codes[op->opcode_index]->builtin_code;
+
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_SVDF) {
+ *op_input_index = 1;
+ } else if (op_code == BuiltinOperator_LSTM) {
+ // TODO(suharshs): Add RNN, and sequential/bidi versions.
+ *op_input_index = 2;
+ } else {
+ return false;
+ }
+ *tensor_idx = op->inputs[*op_input_index];
+
+ // TODO(suharshs): Support shared weights: if two tensors share the
+ // same weight array, things may break (e.g. SSD object detection).
+ if (CountTensorConsumers(model, subgraph, *tensor_idx) != 1) {
+ LOG(INFO) << "Skipping quantization of tensor that is shared between "
+ "multiple multiple operations.";
+ return false;
+ }
+
+ *tensor = subgraph->tensors[*tensor_idx].get();
+
+ if ((*tensor)->type != TensorType_FLOAT32) {
+ LOG(INFO) << "Skipping quantization of tensor that is not type float.";
+ return false;
+ }
+ const uint64 num_elements = NumElements(*tensor);
+ if (num_elements < kWeightsMinSize) {
+ LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
+ << kWeightsMinSize << " elements (" << num_elements << ").";
+ return false;
+ }
+
+ return true;
+}
+
+// Quantizes tensor using asymmetric quantization with the min and max elements
+// of the tensor. The resulting quantization parameters are what the
+// Dequantize operation consumes.
+TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64 num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor with " << num_elements << " elements.";
+
+ // Compute the quantization params.
+ float min_value = *std::min_element(float_data, float_data + num_elements);
+ float max_value = *std::max_element(float_data, float_data + num_elements);
+ GetQuantizationParams(min_value, max_value, 0, 255,
+ tensor->quantization.get());
+
+ // Quantize the buffer.
+ std::vector<uint8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+ const double inverse_scale = 1. / tensor->quantization->scale[0];
+ for (std::size_t i = 0; i < num_elements; i++) {
+ const float src_val = float_data[i];
+ double scaled_val;
+ if (tensor->quantization->scale[0] == 0) {
+ scaled_val = tensor->quantization->zero_point[0];
+ } else {
+ scaled_val =
+ tensor->quantization->zero_point[0] + inverse_scale * src_val;
+ }
+ uint8_t integer_val = static_cast<uint8_t>(std::round(scaled_val));
+ quantized_buffer[i] = integer_val;
+ }
+ model->buffers[tensor->buffer]->data = quantized_buffer;
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
+// Returns the index of the Dequantize op_code.
+// If a Dequantize op_code doesn't exist, adds it and returns its index.
+int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
+ for (int i = 0; i < model->operator_codes.size(); ++i) {
+ if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
+ return i;
+ }
+ }
+ model->operator_codes.push_back(std::make_unique<OperatorCodeT>());
+ int op_code_idx = model->operator_codes.size() - 1;
+ model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
+ // TODO(suharshs): How should the version be set in this op_code?
+
+ // Return the index of the newly placed OperatorCodeT.
+ return op_code_idx;
+}
+
+// Creates a Dequantize OperatorT object.
+void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
+ int32_t input, int32_t output) {
+ OperatorT* op_raw = new OperatorT;
+ op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
+ op_raw->inputs = {input};
+ op_raw->outputs = {output};
+
+ op->reset(op_raw);
+}
+
+// Create a new TensorT object.
+void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+ std::unique_ptr<TensorT>* tensor) {
+ TensorT* tensor_raw = new TensorT;
+ tensor_raw->name = name;
+ tensor_raw->shape = shape;
+
+ tensor->reset(tensor_raw);
+}
+
+} // namespace
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model) {
+ std::unique_ptr<ModelT> model;
+ model.reset(input_model->UnPack());
+
+ // TODO(suharshs): When models support multiple subgraphs, add support.
+ if (model->subgraphs.size() != 1) {
+ LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
+ "subgraph.";
+ return kTfLiteError;
+ }
+
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+
+ std::vector<std::unique_ptr<OperatorT>> new_operators;
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+
+ TensorT* tensor;
+ // The index of the weight tensor in subgraph->tensors.
+ int32_t tensor_idx;
+ int32_t op_input_idx; // The index of tensor_idx in the op->inputs.
+ // TODO(suharshs): Support hybrid ops that require symmetric quantization.
+ if (GetQuantizableTensorFromOperator(model.get(), op, &tensor, &tensor_idx,
+ &op_input_idx)) {
+ // Quantize the tensors.
+ TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(model.get(), tensor));
+
+ // Create a new tensor to be the output of the dequantize op.
+ std::unique_ptr<TensorT> dequantize_output;
+ MakeTensor(tensor->name + "_dequantize", tensor->shape,
+ &dequantize_output);
+ int32_t dequantize_output_idx = subgraph->tensors.size();
+ subgraph->tensors.push_back(std::move(dequantize_output));
+
+ // Create the Dequantize operation.
+ std::unique_ptr<OperatorT> dequantize_op;
+ MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
+ dequantize_output_idx);
+
+ // Update the op_input of tensor_idx to dequantize_output_idx.
+ op->inputs[op_input_idx] = dequantize_output_idx;
+ // Insert the updated op.
+ new_operators.push_back(std::move(subgraph->operators[i]));
+
+ // Insert the newly created Dequantize operation.
+ new_operators.push_back(std::move(dequantize_op));
+ } else {
+ // If this tensor wasn't quantizable, just copy the op over as-is.
+ new_operators.push_back(std::move(subgraph->operators[i]));
+ }
+ }
+ // At this point all unique_ptrs in the original operators vector are
+ // invalid, so we replace the vector wholesale with new_operators.
+ subgraph->operators = std::move(new_operators);
+
+ flatbuffers::Offset<Model> output_model_location =
+ Model::Pack(*builder, model.get());
+ FinishModelBuffer(*builder, output_model_location);
+
+ return kTfLiteOk;
+}
+
+} // namespace optimize
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
new file mode 100644
index 0000000000..a408c1662d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+
+#include <memory>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+
+// Quantizes input_model and populates the provided builder with the new model.
+//
+// A tflite::Model can be obtained from the builder with:
+// const uint8_t* buffer = builder->GetBufferPointer();
+// tflite::Model* model = GetModel(buffer);
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model);
+
+} // namespace optimize
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
new file mode 100644
index 0000000000..0e0676e5ff
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <memory>
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+namespace {
+
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+ int GetElementsNum(const TensorT* tensor) {
+ int tensor_size = 1;
+ for (const int dim : tensor->shape) {
+ tensor_size *= dim;
+ }
+ return tensor_size;
+ }
+
+ const OperatorT* GetOpWithOutput(const SubGraphT* subgraph,
+ int32_t output_tensor_idx) {
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ if (std::find(op->outputs.begin(), op->outputs.end(),
+ output_tensor_idx) != op->outputs.end()) {
+ return op;
+ }
+ }
+ return nullptr;
+ }
+
+ void CheckWeights(const Model* model_packed) {
+ std::unique_ptr<ModelT> model;
+ model.reset(model_packed->UnPack());
+
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ const BuiltinOperator op_code =
+ model->operator_codes[op->opcode_index]->builtin_code;
+
+ // These are the operations that should be quantized.
+ int32_t tensor_idx;
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED) {
+ tensor_idx = op->inputs[1];
+ } else if (op_code == BuiltinOperator_LSTM) {
+ // TODO(suharshs): Add tests for LSTMs.
+ tensor_idx = op->inputs[1];
+ } else {
+ continue;
+ }
+ const TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ int tensor_size = GetElementsNum(tensor);
+ // If the tensor_size is less than 1024 we expect the tensor to remain
+ // unquantized.
+ if (tensor_size < 1024) {
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ // The weight tensor should not come from a dequantize op.
+ ASSERT_TRUE(preceding_op == nullptr);
+ } else {
+ // The input to the op should still be float.
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op != nullptr);
+ // The float input should be the dequantize output.
+ ASSERT_TRUE(
+ model->operator_codes[preceding_op->opcode_index]->builtin_code ==
+ BuiltinOperator_DEQUANTIZE);
+ // Finally, ensure that the input to the dequantize operation is
+ // quantized.
+ ASSERT_TRUE(subgraph->tensors[preceding_op->inputs[0]]->type ==
+ TensorType_UINT8);
+ // TODO(suharshs): Add more rigorous testing for the numerical values in
+ // the tensors.
+ }
+ }
+ }
+};
+
+TEST_F(QuantizeWeightsTest, SimpleTest) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(output_model);
+}
+
+// TODO(suharshs): Add tests that run the resulting model.
+
+} // namespace
+} // namespace optimize
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index a328670526..bbf5d3f30c 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2532,7 +2532,8 @@ def sparse_recall_at_top_k(labels,
name=name_scope)
-def _compute_recall_at_precision(tp, fp, fn, precision, name):
+def _compute_recall_at_precision(tp, fp, fn, precision, name,
+ strict_mode=False):
"""Helper function to compute recall at a given `precision`.
Args:
@@ -2541,17 +2542,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name):
fn: The number of false negatives.
precision: The precision for which the recall will be calculated.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ no smaller than the target precision, return the corresponding recall at
+ the threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
The recall at a given `precision`.
"""
precisions = math_ops.div(tp, tp + fp + _EPSILON)
- tf_index = math_ops.argmin(
- math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ if not strict_mode:
+ tf_index = math_ops.argmin(
+ math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ # Now, we have the implicit threshold, so compute the recall:
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
+ else:
+ # We aim to find the threshold with the smallest precision that is still
+ # no smaller than the target precision.
+ # The rationale:
+ # 1. Compute the difference between the precisions (at each threshold) and
+ # the target precision.
+ # 2. Take the reciprocal of the values from step 1. This makes positive
+ # values rank before negative values, and smaller positive differences
+ # rank before larger ones (their reciprocals are larger).
+ tf_index = math_ops.argmax(
+ math_ops.div(1.0, precisions - precision + _EPSILON),
+ 0,
+ output_type=dtypes.int32)
+
+ def _return_good_recall():
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
- # Now, we have the implicit threshold, so compute the recall:
- return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
- name)
+ return control_flow_ops.cond(precisions[tf_index] >= precision,
+ _return_good_recall, lambda: .0)
def recall_at_precision(labels,
@@ -2561,7 +2587,8 @@ def recall_at_precision(labels,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
- name=None):
+ name=None,
+ strict_mode=False):
"""Computes `recall` at `precision`.
The `recall_at_precision` function creates four local variables,
@@ -2593,6 +2620,11 @@ def recall_at_precision(labels,
updates_collections: An optional list of collections that `update_op` should
be added to.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ no smaller than the target precision, return the corresponding recall at the
+ threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
recall: A scalar `Tensor` representing the recall at the given
@@ -2621,10 +2653,11 @@ def recall_at_precision(labels,
predictions, labels, thresholds, weights)
recall = _compute_recall_at_precision(values['tp'], values['fp'],
- values['fn'], precision, 'value')
+ values['fn'], precision, 'value',
+ strict_mode)
update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
update_ops['fn'], precision,
- 'update_op')
+ 'update_op', strict_mode)
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 1c2c17960a..024bd54912 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -3467,6 +3467,60 @@ class RecallAtPrecisionTest(test.TestCase):
self.assertAlmostEqual(target_recall, sess.run(update_op))
self.assertAlmostEqual(target_recall, recall.eval())
+ def _test_strict_mode(self, strict_mode, target_precision, expected_recall):
+ num_thresholds = 11
+ predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1]
+ labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+ # Resulting thresholds and the corresponding precision and recall values at
+ # each threshold:
+ # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
+ # precisions: [0.3 0.2 0.1 0 0 0 0 0 0]
+ # recalls: [1.0 0.7 0.3 0 0 0 0 0 0]
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ recall, update_op = metrics.recall_at_precision(
+ labels,
+ predictions,
+ num_thresholds=num_thresholds,
+ precision=target_precision,
+ strict_mode=strict_mode)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expected_recall, sess.run(update_op))
+ self.assertAlmostEqual(expected_recall, recall.eval())
+
+ def testStrictMode_Off(self):
+ # strict_mode is turned off, so we return the recall at the threshold whose
+ # precision (0.3) is closest to the target precision (0.9). The recall
+ # corresponding to the threshold is 1.0.
+ self._test_strict_mode(
+ strict_mode=False, target_precision=0.9, expected_recall=1.0)
+
+ def testStrictMode_OnAndFail(self):
+ # strict_mode is turned on and we fail to reach the target precision at any
+ # threshold.
+ # Target precision: 0.9
+ # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9]
+ # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1]
+ # Max index: 3 and the corresponding precision is 0, which is smaller than
+ # the target precision 0.9. As a result, the expected recall is 0.
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.9, expected_recall=.0)
+
+ def testStrictMode_OnAndSucceed(self):
+ # strict_mode is on and we can reach the target precision at certain
+ # threshold.
+ # Target precision: 0.2
+ # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2]
+ # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0]
+ # Max index: 1 and the corresponding precision is 0.2, which is no smaller
+ # than the target precision 0.2. In this case, we return the recall at
+ # index 1, which is 2.0/3 (~0.67).
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3)
+
class PrecisionAtRecallTest(test.TestCase):
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index d43884481a..99c5800391 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example,
num_total_features += num_sparse;
}
}
- int rand_feature = rng_->Uniform(num_total_features);
+ int rand_feature = 0;
+ {
+ mutex_lock lock(mu_);
+ rand_feature = rng_->Uniform(num_total_features);
+ }
if (rand_feature < available_features_.size()) { // it's dense.
*feature_id = available_features_[rand_feature];
*type = input_spec_.GetDenseFeatureType(rand_feature);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index 95f75b4d7e..4945b53007 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorforest {
@@ -120,6 +121,8 @@ class TensorDataSet {
int32 split_sampling_random_seed_;
std::unique_ptr<random::PhiloxRandom> single_rand_;
std::unique_ptr<random::SimplePhilox> rng_;
+ // Mutex guarding the random number generator.
+ mutable mutex mu_;
};
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index a0fc3e43a9..122a67a407 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -279,6 +279,7 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:graph",
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 0f5abe6898..c98b07ad8b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index a5e8277ba5..1d1cb48e8e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -111,17 +111,24 @@ def reset_tpu_sessions():
# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(tpu_cluster_resolver=None): # pylint: disable=invalid-name
+def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name
"""Construct a TPUDistributionStrategy."""
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
+ # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
+ # We are including this for (a) backwards compatibility for open sourced
+ # releases of TensorFlow and (b) to work around a circular dependency
+ # where keras_support and tpu_strategy depend on each other. Once we release
+ # a final version and remove support for the old API, this will be deleted.
+ # (See bug above for more details)
if tpu_cluster_resolver is None:
tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
- if len(args) == 3:
+ if len(args) == 4:
logging.info('Detected new TPUStrategy API.')
- return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1)
+ return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
+ steps_per_run=1,
+ num_cores=num_cores)
else:
logging.info('Detected old TPUStrategy API.')
strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 60830b7d60..836c3ce34e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -375,6 +375,7 @@ cc_library(
":lib_platform",
":platform_base",
"//tensorflow/core/platform/default/build_config:port",
+ "@com_google_absl//absl/base",
"@snappy",
],
)
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index b5a51d2526..97b6971c5b 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -37,6 +37,8 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -553,6 +555,11 @@ bool ReplaceTensorWithConstant(
Status ConstantFold(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph, bool* was_mutated) {
+ // TensorFlow flushes denormals to zero and rounds to nearest, so we do
+ // the same here.
+ port::ScopedFlushDenormal flush;
+ port::ScopedSetRound round(FE_TONEAREST);
+
DumpGraph("Before", graph);
ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
if (generate_new_name == nullptr) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
index ea1b04feeb..4bc88ffc8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
@@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}
+Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) {
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ device->tensorflow_gpu_device_info();
+ gpu_info->event_mgr->ThenExecute(stream, func);
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index 8370c63842..3603808152 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -60,6 +60,9 @@ class GPUDeviceContext : public DeviceContext {
void MaintainLifetimeOnStream(const Tensor* t,
se::Stream* stream) const override {}
+ Status ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) override;
+
private:
int stream_id_;
// The default primary stream to use for this context.
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b184fd91e1..794250a2c1 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -89,6 +89,15 @@ class DeviceContext : public core::RefCounted {
Tensor* cpu_tensor, StatusCallback done) {
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
}
+
+ // If possible, wait for all events on *stream to complete, then execute func.
+ // A non-OK Status is returned otherwise. The stream argument should be the
+ // one provided by GpuDeviceInfo. This function is not applicable to devices
+ // that don't provide such a value.
+ virtual Status ThenExecute(Device* device, stream_executor::Stream* stream,
+ std::function<void()> func) {
+ return errors::Internal("ThenExecute not supported by device");
+ }
};
// map[i] is the DeviceContext* for the node with id i, if i < map.size().
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 12e3e46f65..f543dca49e 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -45,6 +45,8 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set)
for (const auto& device : device_set_->devices()) {
DeviceProperties props = GetDeviceInfo(device->parsed_name());
if (props.type() == "UNKNOWN") continue;
+ auto attrs = device->attributes();
+ props.set_memory_size(attrs.memory_limit());
devices_[device->name()] = props;
}
}
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index a60e3c7a9f..0690640ffa 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <unordered_map>
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index a5736d40b1..b01aca610a 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 231c7c63be..6710ff9df3 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -804,8 +805,9 @@ class SymbolicShapeRefiner {
CHECK_NOTNULL(function_library_.Find(function_node->op()));
GrapplerFunctionItem grappler_function_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
- *function_def, function_library_, &grappler_function_item));
+ TF_RETURN_IF_ERROR(
+ MakeGrapplerFunctionItem(*function_def, function_library_,
+ graph_def_version_, &grappler_function_item));
if (grappler_function_item.inputs().size() > function_node->input_size()) {
return errors::FailedPrecondition(
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 5acfb56b05..8938b7c32e 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -783,6 +785,46 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
+ // Create a graph with a function that takes a scalar value, so that we use
+ // a Placeholder with a scalar shape as input to function shape inference.
+ // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
+ // the input; all tensors are scalars.
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: float"}, // Inputs
+ {"out: float"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ Output placeholder =
+ ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({})));
+ Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
+ auto _identity = tensorflow::ops::AsNodeOut(s, identity);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // TensorFlow versions < 21 infer the output shape of a Placeholder with an
+ // empty shape as unknown, instead of scalar.
+ EXPECT_GT(item.graph.versions().producer(), 21);
+
+ // MyFunc output shouldn't be unknown rank.
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+}
+
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
// Test graph produced in python using:
/*
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 0341d7f8e1..71f4d9fd05 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 9e579098ef..998bd59dce 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6e3ebdee12..037a823096 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -880,10 +880,15 @@ Costs VirtualScheduler::Summary() const {
// Print per device summary
VLOG(1) << "Devices:";
Costs critical_path_costs = Costs::ZeroCosts();
+ std::vector<string> device_names;
+ device_names.reserve(device_.size());
+ for (auto& it : device_) {
+ device_names.push_back(it.first);
+ }
+ std::sort(device_names.begin(), device_names.end());
- for (const auto& device : device_) {
- const auto& name = device.first;
- const auto& state = device.second;
+ for (const auto& name : device_names) {
+ const auto& state = device_.at(name);
std::map<string, int64> op_to_memory;
// First profile only persistent memory usage.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index b1373d8317..02a379fca8 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 288587ce9b..029515ad3c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index caaa5ac8db..a8af169e28 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -827,11 +827,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:op_types",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/grappler/costs:graph_properties",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 889445bbd6..4fb2fe6883 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b9765b9292..5bf45af6b3 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3047,6 +3047,39 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}
+TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
+ // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
+ // Make sure constant folding behaves the same way as TensorFlow.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a =
+ ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
+ Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
+ Output c = ops::Mul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch.push_back("c");
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(1, output.node_size());
+
+ const NodeDef& node_d = output.node(0);
+ EXPECT_EQ("c", node_d.name());
+ EXPECT_EQ("Const", node_d.op());
+
+ std::vector<string> fetch = {"c"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 23f35050f2..92551a0459 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
index 00ad7494f4..79d9ea1608 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h
index 8414b5b8ca..c9dfb6dc0b 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.h
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace Eigen {
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 645e4c2087..56364f0095 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature(
}
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
+ const int graph_def_version,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
VLOG(2) << "Specialize function instantiation: "
@@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
// pushing all constant inputs into the function body.
GrapplerFunctionItem item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib,
+ graph_def_version, &item));
// Push const inputs into the function body, and keep track of their control
// dependencies.
@@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node,
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
const FunctionOptimizerContext& ctx,
- GraphDef* optimized_graph) {
+ const int graph_def_version, GraphDef* optimized_graph) {
VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
GrapplerFunctionItem item;
- Status item_status =
- MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item);
+ Status item_status = MakeGrapplerFunctionItem(
+ func, func_attr, ctx.function_library(), graph_def_version, &item);
if (!item_status.ok()) {
return errors::InvalidArgument("Failed to inline function ", func_node.op(),
@@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
if (func_body_node_func != nullptr) {
// Recursively inline function calls.
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
- ctx, optimized_graph));
+ ctx, graph_def_version,
+ optimized_graph));
} else {
// Annotate the node with the function attributes.
for (const auto& attr : func.attr()) {
@@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (inline_func && ctx.IsInlinedFunction(func_name)) {
// Inline function body into the optimized graph}
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineFunction(node, *func, ctx, optimized_graph));
+ InlineFunction(node, *func, ctx, item.graph.versions().producer(),
+ optimized_graph));
continue;
}
@@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- SpecializeFunction(node, *func, &ctx, optimized_graph));
+ SpecializeFunction(node, *func, item.graph.versions().producer(),
+ &ctx, optimized_graph));
continue;
}
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1be5f8dcc2..91794cefe5 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e778b7879d..5fd34efeb1 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -361,7 +361,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Make a GrapplerItem from a FunctionDef.
GrapplerFunctionItem func_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ func, flib, item.graph.versions().producer(), &func_item));
// Optimize function body graph.
GraphDef optimized_func_graph;
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
index 89847f83d4..b033cff8e6 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/testlib.h"
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 26c54df56b..caa0b7b0cb 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 462b752316..a2c363ea6e 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -307,8 +308,8 @@ GrapplerFunctionItem::GrapplerFunctionItem(
const AttrValueMap& func_attr,
const std::vector<InputArgExpansion>& input_arg_expansions,
const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body)
+ const std::vector<string>& keep_nodes, const int graph_def_version,
+ bool is_stateful, GraphDef&& function_body)
: description_(description),
func_attr_(func_attr),
input_arg_expansions_(input_arg_expansions),
@@ -318,6 +319,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
keep_ops = keep_nodes;
// Swap the graph body.
graph.Swap(&function_body);
+ graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
@@ -472,6 +474,7 @@ Status InstantiationBodyParameters(
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
const OpDef& signature = func.signature();
@@ -595,14 +598,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(), /*description=*/signature.description(),
/*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
- inputs, outputs, keep_nodes, is_stateful, std::move(function_body));
+ inputs, outputs, keep_nodes, graph_def_version, is_stateful,
+ std::move(function_body));
return Status::OK();
}
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
- return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item);
+ return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, graph_def_version,
+ item);
}
// Register GrapplerFunctionItem input arg expansion and function body outputs
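
For context, the producer version the optimizer now threads through is just the `producer` field of the body graph's `VersionDef`. A minimal sketch of that field from Python, assuming only the standard proto bindings (the version number below is made up):

```python
from tensorflow.core.framework import graph_pb2

# Mirror graph.mutable_versions()->set_producer(graph_def_version): every
# instantiated function body now carries the producer version of the graph
# that instantiated it, instead of defaulting to 0.
body = graph_pb2.GraphDef()
body.versions.producer = 26  # hypothetical TF_GRAPH_DEF_VERSION
assert body.versions.producer == 26
```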
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 9f607dc2ee..61588ceb83 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -141,8 +141,8 @@ class GrapplerFunctionItem : public GrapplerItem {
const AttrValueMap& func_attr,
const std::vector<InputArgExpansion>& input_arg_expansions,
const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body);
+ const std::vector<string>& keep_nodes, const int graph_def_version,
+ bool is_stateful, GraphDef&& function_body);
const string& description() const;
@@ -222,6 +222,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a GrapplerFunction item from the function definition. Function must be
@@ -231,6 +232,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// without specializing it to its instantiation attributes (at least types)?
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a FunctionDef from the GrapplerFunctionItem. Use function library
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index b2d059e0ac..b51f2781b8 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace grappler {
@@ -239,7 +240,8 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("XTimesTwo", item.id);
EXPECT_EQ(4, item.function_body().node_size());
@@ -314,7 +316,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("SubGrad", item.id);
EXPECT_EQ(12, item.function_body().node_size());
@@ -395,7 +398,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
func_attr["T"].set_type(DT_FLOAT);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
@@ -456,7 +460,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("Exp", item.output(0).output_tensors[0]);
@@ -499,7 +504,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("ForwardInputs", item.id);
EXPECT_EQ(5, item.function_body().node_size());
@@ -545,7 +551,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(0, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -584,7 +591,8 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef specialized;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
@@ -622,7 +630,8 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(2, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -713,7 +722,8 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
// Replace function body with identity function
item.SwapFunctionBody(std::move(id_func_body));
@@ -754,7 +764,8 @@ TEST_F(FunctionsTest, FunctionDefGrapplerFunctionItemRoundTrip) {
GrapplerFunctionItem item;
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_INT32);
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef func2;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &func2));
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 82ff2a365d..7716043055 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -237,6 +237,7 @@ cc_library(
srcs = ["parse_example_dataset_op.cc"],
deps = [
":parallel_map_iterator",
+ "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
],
)
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index cc5007ee92..6a0522e4f3 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <deque>
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/util/example_proto_fast_parsing.h"
@@ -166,8 +167,6 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
- device_threadpool_(
- ctx->device()->tensorflow_cpu_worker_threads()->workers),
dense_defaults_(std::move(dense_defaults)),
sparse_keys_(std::move(sparse_keys)),
dense_keys_(std::move(dense_keys)),
@@ -190,6 +189,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {
+ thread::ThreadPool* device_threadpool =
+ ctx->lib()->device()->tensorflow_cpu_worker_threads()->workers;
std::vector<string> slice_vec;
for (Tensor t : input_element) {
auto serialized_t = t.flat<string>();
@@ -205,7 +206,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
config.collect_feature_stats = true;
}
example::Result example_result;
- Status s = FastParseExample(config, slice_vec, {}, device_threadpool_,
+ Status s = FastParseExample(config, slice_vec, {}, device_threadpool,
&example_result);
if (s.ok()) {
(*result).resize(key_to_output_index_.size());
@@ -339,7 +340,6 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
private:
const DatasetBase* const input_;
- thread::ThreadPool* const device_threadpool_;
const std::vector<Tensor> dense_defaults_;
const std::vector<string> sparse_keys_;
const std::vector<string> dense_keys_;
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 9ec83b867f..aa70ee06f5 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -196,6 +196,9 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
}
output(sample) = z * stddev + mean;
sample++;
+ if (sample >= limit_sample) {
+ break;
+ }
numIterations = 0;
} else {
numIterations++;
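
The shape of this fix: candidates are accepted in a loop, and without an explicit bound check an accepting iteration can advance `sample` past the requested count. A minimal numpy sketch of the guarded rejection loop (names hypothetical; assumes the interval has nonzero mass):

```python
import numpy as np

def truncated_samples(mean, stddev, lo, hi, limit_sample, seed=0):
    rng = np.random.default_rng(seed)
    out = np.empty(limit_sample)
    sample = 0
    while sample < limit_sample:
        for z in rng.standard_normal(64):       # a chunk of candidates
            value = z * stddev + mean
            if lo <= value <= hi:               # accepted draw
                out[sample] = value
                sample += 1
                if sample >= limit_sample:      # the added guard
                    break
    return out

vals = truncated_samples(0.0, 1.0, -2.0, 2.0, limit_sample=5)
assert len(vals) == 5
```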
diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc
index 8c28620ff9..fface033cb 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.cc
+++ b/tensorflow/core/lib/monitoring/collection_registry.cc
@@ -38,15 +38,15 @@ void Collector::CollectMetricDescriptor(
mutex_lock l(mu_);
return collected_metrics_->metric_descriptor_map
.insert(std::make_pair(
- std::string(metric_def->name()),
+ string(metric_def->name()),
std::unique_ptr<MetricDescriptor>(new MetricDescriptor())))
.first->second.get();
}();
- metric_descriptor->name = std::string(metric_def->name());
- metric_descriptor->description = std::string(metric_def->description());
+ metric_descriptor->name = string(metric_def->name());
+ metric_descriptor->description = string(metric_def->description());
for (const StringPiece label_name : metric_def->label_descriptions()) {
- metric_descriptor->label_names.push_back(std::string(label_name));
+ metric_descriptor->label_names.emplace_back(label_name);
}
metric_descriptor->metric_kind = metric_def->kind();
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 20f0444f8b..c204d52cfe 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -72,7 +72,7 @@ class MetricCollector {
registration_time_millis_(registration_time_millis),
collector_(collector),
point_set_(point_set) {
- point_set_->metric_name = std::string(metric_def->name());
+ point_set_->metric_name = string(metric_def->name());
}
const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
@@ -261,7 +261,7 @@ class Collector {
auto* const point_set = [&]() {
mutex_lock l(mu_);
return collected_metrics_->point_set_map
- .insert(std::make_pair(std::string(metric_def->name()),
+ .insert(std::make_pair(string(metric_def->name()),
std::unique_ptr<PointSet>(new PointSet())))
.first->second.get();
}();
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 6f94685665..756e5c2af8 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -98,8 +98,8 @@ class AbstractMetricDef {
const std::vector<string>& label_descriptions)
: kind_(kind),
value_type_(value_type),
- name_(std::string(name)),
- description_(std::string(description)),
+ name_(name),
+ description_(description),
label_descriptions_(std::vector<string>(label_descriptions.begin(),
label_descriptions.end())) {}
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index e0a5281d68..959290ba8c 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -140,11 +140,11 @@ inline bool ProtoParseNumeric(StringPiece s, uint64* value) {
}
inline bool ProtoParseNumeric(StringPiece s, float* value) {
- return safe_strtof(std::string(s).c_str(), value);
+ return safe_strtof(s, value);
}
inline bool ProtoParseNumeric(StringPiece s, double* value) {
- return safe_strtod(std::string(s).c_str(), value);
+ return safe_strtod(s, value);
}
// Convert strings to number of type T.
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
index cab8f81585..3aba5ec80e 100644
--- a/tensorflow/core/lib/strings/str_util.cc
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -332,7 +332,7 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all) {
// TODO(jlebar): We could avoid having to shift data around in the string if
// we had a StringPiece::find() overload that searched for a StringPiece.
- string res = std::string(s);
+ string res(s);
size_t pos = 0;
while ((pos = res.find(oldsub.data(), pos, oldsub.size())) != string::npos) {
res.replace(pos, oldsub.size(), newsub.data(), newsub.size());
@@ -448,8 +448,7 @@ bool SplitAndParseAsFloats(StringPiece text, char delim,
std::vector<float>* result) {
return SplitAndParseAsInts<float>(text, delim,
[](StringPiece str, float* value) {
- return strings::safe_strtof(
- std::string(str).c_str(), value);
+ return strings::safe_strtof(str, value);
},
result);
}
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index 58e87fcb9e..9f52cf29fc 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -205,7 +205,7 @@ std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) {
if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
StringPiece token(text.data() + token_start, i - token_start);
if (p(token)) {
- result.push_back(std::string(token));
+ result.emplace_back(token);
}
token_start = i + 1;
}
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 47c59d435b..afc4201e53 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -92,7 +92,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
StringPiece scheme, host, path;
io::ParseURI(fname, &scheme, &host, &path);
- FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme));
+ FileSystem* file_system = file_system_registry_->Lookup(string(scheme));
if (!file_system) {
if (scheme.empty()) {
scheme = "[local]";
@@ -166,7 +166,7 @@ bool Env::FilesExist(const std::vector<string>& files,
for (const auto& file : files) {
StringPiece scheme, host, path;
io::ParseURI(file, &scheme, &host, &path);
- files_per_fs[std::string(scheme)].push_back(file);
+ files_per_fs[string(scheme)].push_back(file);
}
std::unordered_map<string, Status> per_file_status;
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 922773684b..3ab542a5d8 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -158,7 +158,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
std::reverse(sub_dirs.begin(), sub_dirs.end());
// Now create the directories.
- string built_path = std::string(remaining_dir);
+ string built_path(remaining_dir);
for (const StringPiece sub_dir : sub_dirs) {
built_path = io::JoinPath(built_path, sub_dir);
Status status = CreateDir(io::CreateURI(scheme, host, built_path));
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
index 0ba0e6304f..342cf28e38 100644
--- a/tensorflow/core/platform/file_system_helper.cc
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -59,7 +59,7 @@ Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
string eval_pattern = pattern;
std::vector<string> all_files;
- string dir = std::string(io::Dirname(fixed_prefix));
+ string dir(io::Dirname(fixed_prefix));
// If dir is empty then we need to fix up fixed_prefix and eval_pattern to
// include . as the top level directory.
if (dir.empty()) {
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index c0a16c95f9..a637d42a92 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -125,7 +125,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
ASSERT_EQ(scheme, "ipfs");
ASSERT_EQ(host, "solarsystem");
str_util::ConsumePrefix(&path, "/");
- *parsed_path = std::string(path);
+ *parsed_path = string(path);
}
std::map<string, std::set<string>> celestial_bodies_ = {
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index b281acb2b0..55f1e30880 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -32,7 +32,7 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
if (str_util::ConsumePrefix(&arg, "--") &&
str_util::ConsumePrefix(&arg, flag) &&
str_util::ConsumePrefix(&arg, "=")) {
- *value_parsing_ok = hook(std::string(arg));
+ *value_parsing_ok = hook(string(arg));
return true;
}
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index aee647a1b3..5e2aeb7830 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -259,6 +259,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -290,10 +300,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -328,6 +338,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+ // We may compare logits instead of log probabilities here, since the
+ // two differ by the same normalization offset for every label.
if (logit < label_selection_input_min) {
continue;
}
@@ -341,7 +353,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
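
The new `norm_offset` turns raw logits into proper log probabilities rather than merely max-shifted logits. A small numpy sketch of the same numerically stable log-softmax computation:

```python
import numpy as np

def log_softmax(raw_input):
    # Stable log-softmax, matching the normalization added above:
    # log p_j = logit_j - (max_coeff + log(sum_k exp(logit_k - max_coeff))).
    max_coeff = raw_input.max()
    logsumexp = np.log(np.exp(raw_input - max_coeff).sum())
    norm_offset = max_coeff + logsumexp
    return raw_input - norm_offset

logits = np.array([2.0, 1.0, 0.1])
log_probs = log_softmax(logits)
assert np.isclose(np.exp(log_probs).sum(), 1.0)  # true log probabilities
```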
diff --git a/tensorflow/core/util/env_var.cc b/tensorflow/core/util/env_var.cc
index 8d43bcc927..2604a5d66a 100644
--- a/tensorflow/core/util/env_var.cc
+++ b/tensorflow/core/util/env_var.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
bool* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -48,7 +48,7 @@ Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
int64* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -62,11 +62,11 @@ Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
string* value) {
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val != nullptr) {
*value = tf_env_var_val;
} else {
- *value = std::string(default_val);
+ *value = string(default_val);
}
return Status::OK();
}
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 1fec0010a1..a38cd1d09f 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -353,7 +353,7 @@ bool TestFastParse(const string& serialized, Example* example) {
// I.e. last entry in the map overwrites all the previous ones.
parsed::FeatureMapEntry& name_and_feature =
parsed_example[parsed_example_size - i - 1];
- string name = std::string(name_and_feature.first);
+ string name(name_and_feature.first);
if ((*features.mutable_feature()).count(name) > 0) continue;
auto& value = (*features.mutable_feature())[name];
diff --git a/tensorflow/docs_src/guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md
index a1703058c3..9b64d51b98 100644
--- a/tensorflow/docs_src/guide/premade_estimators.md
+++ b/tensorflow/docs_src/guide/premade_estimators.md
@@ -366,6 +366,8 @@ Running this code yields the following output (or something similar):
Test set accuracy: 0.967
```
+The `eval_result` dictionary also contains the `average_loss` (mean loss per sample), the `loss` (mean loss per mini-batch), and the value of the estimator's `global_step` (the number of training iterations it underwent), as shown in the sketch below.
+
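
A sketch of reading those entries; the names follow the guide's earlier snippets (`classifier`, `iris_data.eval_input_fn`, `test_x`, `test_y`, `args`):

```python
eval_result = classifier.evaluate(
    input_fn=lambda: iris_data.eval_input_fn(test_x, test_y, args.batch_size))

print(eval_result['accuracy'])      # e.g. 0.967
print(eval_result['average_loss'])  # mean loss per sample
print(eval_result['loss'])          # mean loss per mini-batch
print(eval_result['global_step'])   # number of training iterations performed
```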
### Making predictions (inferring) from the trained model
We now have a trained model that produces good evaluation results.
diff --git a/tensorflow/docs_src/guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md
index 6c967fd882..33ab891861 100644
--- a/tensorflow/docs_src/guide/saved_model.md
+++ b/tensorflow/docs_src/guide/saved_model.md
@@ -2,7 +2,7 @@
The `tf.train.Saver` class provides methods to save and restore models. The
`tf.saved_model.simple_save` function is an easy way to build a
-`tf.saved_model` suitable for serving. [Estimators](./estimators)
+`tf.saved_model` suitable for serving. [Estimators](../guide/estimators.md)
automatically save and restore variables in the `model_dir`.
## Save and restore variables
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
index babf55cd5f..fb93bb6d8e 100644
--- a/tensorflow/js/ops/ts_op_gen.cc
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -38,6 +38,15 @@ struct ArgDefs {
const ApiDef::Arg& api_def_arg;
};
+// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
+struct OpAttrs {
+ OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
+ : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
+
+ const OpDef::AttrDef& op_def_attr;
+ const ApiDef::Attr& api_def_attr;
+};
+
// Helper class to generate TypeScript code for a given OpDef:
class GenTypeScriptOp {
public:
@@ -49,8 +58,12 @@ class GenTypeScriptOp {
private:
void ProcessArgs();
+ void ProcessAttrs();
+ void AddAttrForArg(const string& attr, int arg_index);
+ string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddMethodSignature();
+ void AddOpAttrs();
void AddMethodReturnAndClose();
const OpDef& op_def_;
@@ -62,6 +75,13 @@ class GenTypeScriptOp {
// Holds in-order vector of Op inputs:
std::vector<ArgDefs> input_op_args_;
+ // Holds in-order vector of Op attributes:
+ std::vector<OpAttrs> op_attrs_;
+
+ // Stores attributes-to-arguments by name:
+ typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
+ AttrArgIdxMap attr_arg_idx_map_;
+
// Holds number of outputs:
int num_outputs_;
};
@@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {}
string GenTypeScriptOp::Code() {
ProcessArgs();
+ ProcessAttrs();
// Generate exported function for Op:
AddMethodSignature();
+ AddOpAttrs();
AddMethodReturnAndClose();
strings::StrAppend(&result_, "\n");
@@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() {
<< api_def_.arg_order(i);
continue;
}
+
+ // Map attr names to arg indexes:
+ if (!op_def_arg->type_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_attr(), i);
+ } else if (!op_def_arg->type_list_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_list_attr(), i);
+ }
+ if (!op_def_arg->number_attr().empty()) {
+ AddAttrForArg(op_def_arg->number_attr(), i);
+ }
+
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
}
num_outputs_ = api_def_.out_arg_size();
}
+void GenTypeScriptOp::ProcessAttrs() {
+ for (int i = 0; i < op_def_.attr_size(); i++) {
+ op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
+ }
+}
+
+void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
+ // Keep track of attributes-to-arguments by name. These will be used for
+ // constructing Op attributes that require information about the inputs.
+ auto iter = attr_arg_idx_map_.find(attr);
+ if (iter == attr_arg_idx_map_.end()) {
+ attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
+ } else {
+ iter->second.push_back(arg_index);
+ }
+}
+
+string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
+ string inputs;
+ auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
+ if (arg_list != attr_arg_idx_map_.end()) {
+ for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
+ ++iter) {
+ strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
+ }
+ }
+ return inputs;
+}
+
void GenTypeScriptOp::AddMethodSignature() {
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
"(");
@@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() {
}
}
+void GenTypeScriptOp::AddOpAttrs() {
+ strings::StrAppend(&result_, " const opAttrs = [\n");
+
+ bool is_first = true;
+ for (auto& attr : op_attrs_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ",\n");
+ }
+
+ // Append 4 spaces to start:
+ strings::StrAppend(&result_, " ");
+
+ if (attr.op_def_attr.type() == "type") {
+ // Type OpAttributes can be generated from a helper function:
+ strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
+ attr.op_def_attr.name(), "', ",
+ InputForAttr(attr.op_def_attr), ")");
+ } else if (attr.op_def_attr.type() == "int") {
+ strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
+ strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
+ strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
+ ".length}");
+ }
+ }
+ strings::StrAppend(&result_, "\n ];\n");
+}
+
void GenTypeScriptOp::AddMethodReturnAndClose() {
strings::StrAppend(&result_, " return null;\n}\n");
}
@@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) {
// This file is MACHINE GENERATED! Do not edit
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)header";
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
index 9a85c021b0..03241689b5 100644
--- a/tensorflow/js/ops/ts_op_gen_test.cc
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
<< "'" << s << "' does not contain '" << expected << "'";
}
-// TODO(kreeger): Add multiple outputs here?
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
@@ -79,50 +78,15 @@ op {
summary: "Summary for op Foo."
description: "Description for op Foo."
}
-op {
- name: "DeprecatedFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output"
- description: "Description for output."
- type: DT_FLOAT
- }
- deprecation {
- explanation: "Deprecated."
- }
-}
-op {
- name: "MultiOutputFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output1"
- description: "Description for output 1."
- type: DT_FLOAT
- }
- output_arg {
- name: "output2"
- description: "Description for output 2."
- type: DT_FLOAT
- }
- summary: "Summary for op MultiOutputFoo."
- description: "Description for op MultiOutputFoo."
-}
)";
// Generate TypeScript code
-// @param api_def_str TODO doc me.
-void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
+ string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
- protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ protobuf::TextFormat::ParseFromString(
+ op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) {
@@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -160,12 +124,10 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -179,34 +141,106 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
+ const string op_def = R"(
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ description: "Description for input."
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
+ const string op_def = R"(
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
+TEST(TsOpGenTest, OpAttrs) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expectedFooAttrs = R"(
+ const opAttrs = [
+ createTensorsTypeOpAttr('T', images),
+ {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
+ ];
+)";
+
+ ExpectContainsStr(ts_file_text, expectedFooAttrs);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e1d3422730..40f98474b5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -723,7 +723,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
- ":cond_v2_impl",
":dtypes",
":framework_ops",
":graph_to_function_def",
@@ -2620,8 +2619,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":constant_op",
+ ":dtypes",
":framework_test_lib",
":sparse_ops",
+ ":sparse_tensor",
],
)
@@ -3245,7 +3246,6 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
- "saver",
":array_ops",
":array_ops_gen",
":checkpoint_management",
@@ -3269,6 +3269,7 @@ py_library(
":random_ops",
":resource_variable_ops",
":resources",
+ ":saver",
":sdca_ops",
":session",
":sparse_ops",
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 6642a5bfb1..e0826a7945 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 23)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 24)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 8a4ac6aaef..55d2709845 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -576,7 +576,6 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
- "nomac",
"oss_serial",
],
deps = [
@@ -1047,7 +1046,6 @@ cuda_py_test(
tags = [
"no_oss", # Incompatible with bazel_pip.
"no_windows",
- "nomac", # TODO(cais): Install of futures and grpcio on all macs.
"notsan",
],
)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index bdabbf4ea3..6f48d38b58 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -237,6 +237,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
+ "//tensorflow/python:cond_v2_impl",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dba9779488..3171ef9d62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import sys
import threading
import numpy as np
@@ -38,6 +39,7 @@ from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
@@ -49,6 +51,10 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+# This is to avoid a circular dependency with cond_v2_impl
+# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
+cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+
def create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
@@ -113,10 +119,6 @@ class CapturingGraph(ops.Graph):
# for resource tensors.
self._last_op_using_resource_tensor = {}
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- return True
-
def clear_resource_control_flow_state(self):
self._last_op_using_resource_tensor = {}
@@ -203,6 +205,8 @@ class FuncGraph(CapturingGraph):
by this function. The Tensors in this structure are the same as those of
self.outputs. Note that this structure might contain Python `None`s.
variables: Variables that should be watched during function execution.
+ outer_graph: The graph this function is defined in. May be another FuncGraph
+ or the global default Graph.
seed: The graph-level random seed.
"""
@@ -222,8 +226,9 @@ class FuncGraph(CapturingGraph):
self.outputs = []
self.structured_outputs = None
self.variables = []
+ self.outer_graph = ops.get_default_graph()
- graph = ops.get_default_graph()
+ graph = self.outer_graph
if context.executing_eagerly():
self.seed = context.global_seed()
@@ -259,6 +264,16 @@ class FuncGraph(CapturingGraph):
return internal_tensor
+ @property
+ def external_captures(self):
+ """External tensors captured by this function."""
+ return list(self.captures.keys())
+
+ @property
+ def internal_captures(self):
+ """Placeholders in this function corresponding captured tensors."""
+ return list(self.captures.values())
+
def _forward_name(n):
"""The name of a generated forward defun named n."""
@@ -695,7 +710,7 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -1069,8 +1084,8 @@ class _PolymorphicFunction(object):
if graph_function is None:
graph_function = GraphCallable(
- _func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
+ func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
self._arguments_to_functions[cache_key] = graph_function
@@ -1469,8 +1484,7 @@ def make_defun_op(func, *args, **kwds):
and which can be called directly the way a `@defun` wrapped function
can.
"""
- return GraphCallable(
- _func_graph_from_py_func(func.__name__, func, args, kwds))
+ return GraphCallable(func_graph_from_py_func(func.__name__, func, args, kwds))
class AutomaticControlDependencies(object):
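
The `sys.modules[__name__]` assignment above is a standard way to break an import cycle: the lower-level module receives a reference to this module at import time instead of importing it. A self-contained simulation of the pattern using synthetic modules (names hypothetical):

```python
import sys
import types

# Two modules with a cycle: `low` needs a function from `high`, but `high`
# already imports `low`. Rather than importing, `high` hands `low` a
# reference to itself through sys.modules -- the same move as
# `cond_v2_impl._function = sys.modules[__name__]` above.
low = types.ModuleType('low')
low._high = None                    # filled in by `high` at import time

def _do_work():
    return low._high.helper()       # late lookup, after registration
low.do_work = _do_work

high = types.ModuleType('high')
high.helper = lambda: 42
sys.modules['high'] = high
low._high = sys.modules['high']     # breaks the cycle

assert low.do_work() == 42
```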
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 3e9bb91d54..4f23b3c4da 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -358,6 +358,47 @@ class FunctionTest(test.TestCase):
self.assertEqual(3.0, float(test_assign_add()))
+ @test_util.run_in_graph_and_eager_modes
+ def testTensorInitializationInFunctionRaisesError(self):
+ error_msg = ('Tensor-typed variable initializers must either be '
+ 'wrapped in an init_scope or callable.*')
+
+ @function.defun
+ def tensor_init():
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
+
+ tensor_init()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCallableTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ v = resource_variable_ops.ResourceVariable(
+ lambda: constant_op.constant(2.0))
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testInitScopeTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ with ops.init_scope():
+ const = constant_op.constant(2.0)
+ v = resource_variable_ops.ResourceVariable(const)
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index f7ee42c7f6..bcbd7b7933 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -431,7 +431,11 @@ class Estimator(object):
Returns:
A dict containing the evaluation metrics specified in `model_fn` keyed by
name, as well as an entry `global_step` which contains the value of the
- global step for which this evaluation was performed.
+ global step for which this evaluation was performed. For canned
+ estimators, the dict contains the `loss` (mean loss per mini-batch) and
+ the `average_loss` (mean loss per sample). Canned classifiers also return
+ the `accuracy`. Canned regressors also return the `label/mean` and the
+ `prediction/mean`.
Raises:
ValueError: If `steps <= 0`.
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 7723fcce74..55aace5fa9 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -311,13 +311,33 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
def _placeholder_from_tensor(t, default_batch_size=None):
+ """Creates a placeholder that matches the dtype and shape of passed tensor.
+
+ Args:
+ t: Tensor or EagerTensor
+ default_batch_size: the number of query examples expected per batch.
+ Leave unset for variable batch size (recommended).
+
+ Returns:
+ Placeholder that matches the passed tensor.
+ """
batch_shape = tensor_shape.TensorShape([default_batch_size])
shape = batch_shape.concatenate(t.get_shape()[1:])
# Reuse the feature tensor's op name (t.op.name) for the placeholder,
# excluding the index from the tensor's name (t.name):
# t.name = "%s:%d" % (t.op.name, t._value_index)
- return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
+ try:
+ name = t.op.name
+ except AttributeError:
+ # In Eager mode, tensors don't have ops or names, and while they do have
+ # IDs, those are not maintained across runs. The name here is used
+ # primarily for debugging, and is not critical to the placeholder.
+ # So, in order to make this Eager-compatible, continue with an empty
+ # name if none is available.
+ name = None
+
+ return array_ops.placeholder(dtype=t.dtype, shape=shape, name=name)
def _placeholders_from_receiver_tensors_dict(input_vals,
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index e87b88327f..3eed1ab163 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -416,6 +416,7 @@ class ExportTest(test_util.TensorFlowTestCase):
tensor_shape.unknown_shape(),
v.receiver_tensors["feature_2"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_serving_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -434,6 +435,7 @@ class ExportTest(test_util.TensorFlowTestCase):
dtypes.int32,
serving_input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -454,6 +456,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(
dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -477,6 +480,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["input", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_batch_size(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -489,6 +493,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape)
self.assertEqual([10], input_receiver.features["feature_1"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -497,6 +502,7 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn(self):
def dummy_input_fn():
return ({"x": constant_op.constant([[1], [1]]),
@@ -514,6 +520,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["x", "y", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
def dummy_input_fn(feature_key="x"):
return ({feature_key: constant_op.constant([[1], [1]]),
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 220c3e58ca..12daddb044 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -51,6 +51,7 @@ _DEFAULT_REPLACEABLE_LIST = [
'device_fn',
'protocol',
'eval_distribute',
+ 'experimental_distribute',
]
_SAVE_CKPT_ERR = (
@@ -331,7 +332,8 @@ class RunConfig(object):
train_distribute=None,
device_fn=None,
protocol=None,
- eval_distribute=None):
+ eval_distribute=None,
+ experimental_distribute=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -469,6 +471,9 @@ class RunConfig(object):
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation,
according to the policy specified by that strategy.
+ experimental_distribute: an optional
+ `tf.contrib.distribute.DistributeConfig` object specifying
+ DistributionStrategy-related configuration.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -508,7 +513,8 @@ class RunConfig(object):
train_distribute=train_distribute,
device_fn=device_fn,
protocol=protocol,
- eval_distribute=eval_distribute)
+ eval_distribute=eval_distribute,
+ experimental_distribute=experimental_distribute)
self._init_distributed_setting_from_environment_var(tf_config)
@@ -810,6 +816,7 @@ class RunConfig(object):
- `device_fn`,
- `protocol`.
- `eval_distribute`,
+ - `experimental_distribute`,
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 9d2babc6e0..9b482237ab 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -2747,6 +2747,62 @@ class FunctionalInputLayerTest(test.TestCase):
variables_lib.Variable)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ def test_fills_cols_to_vars_shared_embedding(self):
+ # Provide 5 DenseColumns to input_layer: a NumericColumn, a
+ # BucketizedColumn, an EmbeddingColumn, and two SharedEmbeddingColumns.
+ # The EmbeddingColumn creates a Variable and the two
+ # SharedEmbeddingColumns share one variable.
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ shared_embedding_a, shared_embedding_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ cols_to_vars = {}
+ all_cols = [
+ price1, dense_feature_bucketized, some_embedding_column,
+ shared_embedding_a, shared_embedding_b
+ ]
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(1, len(cols_to_vars[shared_embedding_a]))
+ # This is a bug in the current implementation and should be fixed in the
+ # new one.
+ self.assertEqual(0, len(cols_to_vars[shared_embedding_b]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ self.assertIsInstance(cols_to_vars[shared_embedding_a][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[shared_embedding_a][0].shape, [3, 2])
+
def test_fills_cols_to_vars_partitioned_variables(self):
price1 = fc.numeric_column('price1')
dense_feature = fc.numeric_column('dense_feature')
@@ -2772,6 +2828,10 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(0, len(cols_to_vars[price1]))
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(
+ 'input_from_feature_columns/input_layer/sparse_feature_embedding/'
+ 'embedding_weights/part_0:0',
+ cols_to_vars[some_embedding_column][0].name)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
@@ -5544,20 +5604,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
- # Expected lookup result, using combiner='mean'.
- expected_lookups_a = (
- # example 0:
- (7., 11.), # ids [2], embedding = [7, 11]
- # example 1:
- (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
- )
- expected_lookups_b = (
- # example 0:
- (1., 2.), # ids [0], embedding = [1, 2]
- # example 1:
- (0., 0.), # ids [], embedding = [0, 0]
- )
-
# Build columns.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b6bf516286..aa66ed77e9 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -142,6 +142,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -155,7 +156,6 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -164,67 +164,148 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
-def _internal_input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None,
- scope=None):
- """See input_layer. `scope` is a name or variable scope to use."""
+class StateManager(object):
+ """Manages the state associated with FeatureColumns.
- feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access
- for column in feature_columns:
- if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be a _DenseColumn. '
- 'You can wrap a categorical column with an '
- 'embedding_column or indicator_column. Given: {}'.format(column))
- weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- output_tensors = []
- ordered_columns = []
- for column in sorted(feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name): # pylint: disable=protected-access
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
- num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
- batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
- if cols_to_vars is not None:
- # Retrieve any variables created (some _DenseColumn's don't create
- # variables, in which case an empty list is returned).
- cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES,
- scope=variable_scope.get_variable_scope().name)
- _verify_static_batch_size_equality(output_tensors, ordered_columns)
- return array_ops.concat(output_tensors, 1)
+ Some `FeatureColumn`s create variables or resources to assist their
+ computation. The `StateManager` is responsible for creating and storing these
+ objects since `FeatureColumn`s are supposed to be stateless configuration
+ only.
+ """
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a new variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ shape: variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ trainable: Whether this variable is trainable or not.
+ initializer: initializer instance (callable).
+
+ Returns:
+ The created variable.
+ """
+ del feature_column, name, shape, dtype, trainable, initializer
+ raise NotImplementedError('StateManager.create_variable')
+
+ def add_variable(self, feature_column, var):
+ """Adds an existing variable to the state.
+
+ Args:
+ feature_column: A `FeatureColumn` object to associate this variable with.
+ var: The variable.
+ """
+ del feature_column, var
+ raise NotImplementedError('StateManager.add_variable')
+
+ def get_variable(self, feature_column, name):
+ """Returns an existing variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ """
+ del feature_column, name
+ raise NotImplementedError('StateManager.get_variable')
+
+ def add_resource(self, feature_column, name, resource):
+ """Creates a new resource.
+
+ Resources can be things such as tables etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this resource corresponds to.
+ name: Name of the resource.
+ resource: The resource.
+ """
+ del feature_column, name, resource
+ raise NotImplementedError('StateManager.add_resource')
+ def get_resource(self, feature_column, name):
+ """Returns an already created resource.
-def input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+ Resources can be things such as lookup tables, etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: Name of the resource.
+ """
+ del feature_column, name
+ raise NotImplementedError('StateManager.get_resource')
+
+
+class _InputLayerStateManager(StateManager):
+ """Manages the state of InputLayer."""
+
+ def __init__(self, layer, feature_columns, trainable):
+ """Creates an _InputLayerStateManager object.
+
+ Args:
+ layer: The input layer this state manager is associated with.
+ feature_columns: List of feature columns for the input layer.
+ trainable: Whether variables created are trainable by default.
+ """
+ self._trainable = trainable
+ self._layer = layer
+ self._cols_to_vars_map = {}
+ self._cols_to_names_map = {}
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ self._cols_to_vars_map[column] = {}
+ base_name = column.name
+ if isinstance(column, SharedEmbeddingColumn):
+ base_name = column.shared_collection_name
+ with variable_scope.variable_scope(base_name) as vs:
+ self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ if name in self._cols_to_vars_map[feature_column]:
+ raise ValueError('Variable already exists.')
+ with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name in self._cols_to_vars_map[feature_column]:
+ return self._cols_to_vars_map[feature_column][name]
+ raise ValueError('Variable does not exist.')
+
+
+class FeatureLayer(Layer):
+ """A layer that produces a dense `Tensor` based on given `feature_columns`.
Generally, a single example in training data is described with FeatureColumns.
At the first layer of the model, this column-oriented data should be converted
to a single `Tensor`.
+ This layer can be called multiple times with different features.
+
Example:
```python
@@ -233,105 +314,122 @@ def input_layer(features,
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- dense_tensor = input_layer(features, columns)
+ feature_layer = FeatureLayer(columns)
+ dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)
- ```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values can be a `SparseTensor` or a `Tensor` depends on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_DenseColumn` such as `numeric_column`, `embedding_column`,
- `bucketized_column`, `indicator_column`. If you have categorical features,
- you can wrap them with an `embedding_column` or `indicator_column`.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to list of `Variable`s. For example, after
- the call, we might have cols_to_vars =
- {_EmbeddingColumn(
- categorical_column=_HashedCategoricalColumn(
- key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
- dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
- <tf.Variable 'some_variable:1' shape=(5, 10)]}
- If a column creates no variables, its value will be an empty list.
-
- Returns:
- A `Tensor` which represents input layer of a model. Its shape
- is (batch_size, first_layer_dimension) and its dtype is `float32`.
- first_layer_dimension is determined based on given `feature_columns`.
-
- Raises:
- ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
- """
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
-
-
-# TODO(akshayka): InputLayer should be a subclass of Layer, and it
-# should implement the logic in input_layer using Layer's build-and-call
-# paradigm; input_layer should create an instance of InputLayer and
-# return the result of invoking its apply method, just as functional layers do.
-class InputLayer(object):
- """An object-oriented version of `input_layer` that reuses variables."""
+ prediction = tf.layers.dense(dense_tensor, 1)."""
def __init__(self,
feature_columns,
- weight_collections=None,
trainable=True,
- cols_to_vars=None):
- """See `input_layer`."""
+ name=None,
+ shared_state_manager=None,
+ **kwargs):
+ """Constructs a FeatureLayer.
- self._feature_columns = feature_columns
- self._weight_collections = weight_collections
- self._trainable = trainable
- self._cols_to_vars = cols_to_vars
- self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
- self._scope = self._input_layer_template.variable_scope
-
- def __call__(self, features):
- return self._input_layer_template(
- features=features,
- feature_columns=self._feature_columns,
- weight_collections=self._weight_collections,
- trainable=self._trainable,
- cols_to_vars=None,
- scope=self._scope)
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the FeatureLayer.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. Unlike the state of regular embedding
+ columns, the state of SharedEmbeddingColumns cannot be owned by the
+ FeatureLayer itself, since SharedEmbeddingColumns can be shared across
+ different FeatureLayers. As a result, users are expected to create a
+ SharedEmbeddingStateManager object that is responsible for managing the
+ shared state and can be passed into different FeatureLayer objects to
+ share state. For example,
+
+ ```python
+ sc_1, sc_2 = shared_embedding_column_v2(...)
+ sc_3, sc_4 = shared_embedding_column_v2(...)
+ ssm = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm)
+ ```
+ Now feature_layer1 and feature_layer2 will share variables. If sharing
+ is not desired, one can create two separate SharedEmbeddingStateManager
+ objects:
+
+ ```python
+ ssm1 = SharedEmbeddingStateManager()
+ ssm2 = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm1)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm2)
+ ```
+ **kwargs: Keyword arguments to construct a layer.
- @property
- def non_trainable_variables(self):
- return self._input_layer_template.non_trainable_variables
+ Raises:
+ ValueError: if an item in `feature_columns` is not a `DenseColumn`.
+ """
+ super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
- @property
- def non_trainable_weights(self):
- return self._input_layer_template.non_trainable_weights
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._state_manager = _InputLayerStateManager(self, self._feature_columns,
+ self.trainable)
+ self._shared_state_manager = shared_state_manager
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if not isinstance(column, DenseColumn):
+ raise ValueError(
+ 'Items of feature_columns must be a DenseColumn. '
+ 'You can wrap a categorical column with an '
+ 'embedding_column or indicator_column. Given: {}'.format(column))
- @property
- def trainable_variables(self):
- return self._input_layer_template.trainable_variables
+ def build(self, _):
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ else:
+ with variable_scope.variable_scope(None, default_name=self.name):
+ column.create_state(self._state_manager)
+ super(FeatureLayer, self).build(None)
- @property
- def trainable_weights(self):
- return self._input_layer_template.trainable_weights
+ def call(self, features, cols_to_output_tensors=None):
+ """Returns a dense tensor corresponding to the `feature_columns`.
- @property
- def variables(self):
- return self._input_layer_template.variables
+ Args:
+ features: A mapping from key to tensors. `FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values can be a `SparseTensor` or a `Tensor`,
+ depending on the corresponding `FeatureColumn`.
+ cols_to_output_tensors: If not `None`, this will be filled with a dict
+ mapping feature columns to output tensors created.
- @property
- def weights(self):
- return self._input_layer_template.weights
+ Returns:
+ A `Tensor` which represents the input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on the given `feature_columns`.
+ """
+ transformation_cache = FeatureTransformationCache(features)
+ output_tensors = []
+ ordered_columns = []
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
+
+ _verify_static_batch_size_equality(output_tensors, ordered_columns)
+ return array_ops.concat(output_tensors, 1)
def linear_model(features,
@@ -565,12 +663,15 @@ class _BiasLayer(base.Layer):
return self._bias_variable
-def _get_expanded_variable_list(variable):
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- return [variable] # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- return list(variable)
+def _get_expanded_variable_list(var_list):
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variables.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
def _strip_leading_slashes(name):
@@ -661,7 +762,7 @@ class _LinearModel(training.Model):
scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
name='weighted_sum')
bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
+ self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
return predictions
def _add_layers(self, layers):
@@ -877,10 +978,15 @@ def embedding_column(
trainable=trainable)
-def shared_embedding_columns(
- categorical_columns, dimension, combiner='mean', initializer=None,
- shared_embedding_collection_name=None, ckpt_to_load_from=None,
- tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+def shared_embedding_columns_v2(categorical_columns,
+ dimension,
+ combiner='mean',
+ initializer=None,
+ shared_embedding_collection_name=None,
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None,
+ max_norm=None,
+ trainable=True):
"""List of dense columns that convert from sparse, categorical input.
This is similar to `embedding_column`, except that it produces a list of
@@ -1803,51 +1909,6 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
-class StateManager(object):
- """Manages the state associated with FeatureColumns.
-
- Some `FeatureColumn`s create variables or resources to assist their
- computation. The `StateManager` is responsible for creating and storing these
- objects since `FeatureColumn`s are supposed to be stateless configuration
- only.
- """
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- """Creates a new variable or returns an existing one.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: variable name.
- shape: variable shape.
- dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
- initializer: initializer instance (callable).
-
- Returns:
- The variable.
- """
- raise NotImplementedError('StateManager.get_variable')
-
- def get_resource(self, feature_column, name, resource_creator):
- """Creates a new resource or returns an existing one.
-
- Resources can be things such as tables etc.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: Name of the resource.
- resource_creator: A callable that can create the resource.
-
- Returns:
- The resource.
- """
- raise NotImplementedError('StateManager.get_resource')
-
-
class FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -2550,6 +2611,17 @@ class EmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
@@ -2558,13 +2630,8 @@ class EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name='embedding_weights')
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
@@ -2637,6 +2704,68 @@ def _get_graph_for_variable(var):
return var.graph
+class SharedEmbeddingStateManager(Layer):
+ """A state manager that handle the state of shared embedding columns.
+
+ This can handle multiple sets of columns that share variables."""
+
+ def __init__(self, trainable=True, name=None, **kwargs):
+ """Constructs a `SharedEmbeddingStateManager`.
+
+ Args:
+ trainable: If true, variables created are trainable.
+ name: Name of the State Manager.
+ **kwargs: Keyword arguments.
+ """
+ super(SharedEmbeddingStateManager, self).__init__(
+ name=name, trainable=trainable, **kwargs)
+ self._var_dict = {}
+
+ def create_variable(self,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a variable.
+
+ Makes sure only one variable is created per `name`. Callers are expected
+ to pass the column's `shared_collection_name` as `name`, so all columns
+ in that collection resolve to the same variable.
+
+ Args:
+ name: Name of the variable (the column's `shared_collection_name`).
+ shape: Variable shape.
+ dtype: Variable type.
+ trainable: Whether the created variable should be trainable.
+ initializer: Variable initializer.
+
+ Returns:
+ A variable or partitioned variable.
+ """
+ if name in self._var_dict:
+ var = self._var_dict[name]
+ return var
+ with variable_scope.variable_scope(
+ self.name, reuse=variable_scope.AUTO_REUSE):
+ var = self.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ trainable=self.trainable and trainable,
+ initializer=initializer,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._var_dict[name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name not in self._var_dict:
+ raise ValueError('Variable name: {} not recognized.'.format(name))
+ return self._var_dict[name]
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
@@ -2675,6 +2804,16 @@ class SharedEmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the shared embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ name=self.shared_collection_name,
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
@@ -2687,13 +2826,8 @@ class SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name=self.shared_collection_name)
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index ad578d287a..6b343ecf3e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -33,12 +33,12 @@ from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
+from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
+from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
-from tensorflow.python.feature_column.feature_column_v2 import InputLayer
from tensorflow.python.feature_column.feature_column_v2 import StateManager
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
-from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -824,22 +824,6 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_hash_bucket('aaa', 10)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column._get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
transformation_cache = FeatureTransformationCache({
@@ -2640,13 +2624,13 @@ class _LinearModelTest(test.TestCase):
sess.run(net, feed_dict={features['price']: np.array(1)})
-class InputLayerTest(test.TestCase):
+class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_retrieving_input(self):
features = {'a': [0.]}
- input_layer = InputLayer(fc_old.numeric_column('a'))
- inputs = self.evaluate(input_layer(features))
+ feature_layer = FeatureLayer(fc.numeric_column('a'))
+ inputs = self.evaluate(feature_layer(features))
self.assertAllClose([[0.]], inputs)
def test_reuses_variables(self):
@@ -2657,7 +2641,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
def _embedding_column_initializer(shape, dtype, partition_info):
@@ -2670,16 +2654,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
- inputs = input_layer(features)
- variables = input_layer.variables
+ inputs = feature_layer(features)
+ variables = feature_layer.variables
# Sanity check: test that the inputs are correct.
self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
@@ -2687,13 +2671,13 @@ class InputLayerTest(test.TestCase):
# Check that only one variable was created.
self.assertEqual(1, len(variables))
- # Check that invoking input_layer on the same features does not create
+ # Check that invoking feature_layer on the same features does not create
# additional variables
- _ = input_layer(features)
+ _ = feature_layer(features)
self.assertEqual(1, len(variables))
- self.assertEqual(variables[0], input_layer.variables[0])
+ self.assertEqual(variables[0], feature_layer.variables[0])
- def test_feature_column_input_layer_gradient(self):
+ def test_feature_column_feature_layer_gradient(self):
with context.eager_mode():
sparse_input = sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (2, 0)),
@@ -2701,7 +2685,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
@@ -2715,16 +2699,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
def scale_matrix():
- matrix = input_layer(features)
+ matrix = feature_layer(features)
return 2 * matrix
# Sanity check: Verify that scale_matrix returns the correct output.
@@ -2739,185 +2723,139 @@ class InputLayerTest(test.TestCase):
self.assertAllEqual([0, 1, 2], indexed_slice.indices)
self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
-
-class FunctionalInputLayerTest(test.TestCase):
-
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.input_layer(features={}, feature_columns=[])
+ FeatureLayer(feature_columns=[])(features={})
def test_should_be_dense_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- ])
+ with self.assertRaisesRegexp(ValueError, 'must be a DenseColumn'):
+ FeatureLayer(feature_columns=[
+ fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])(
+ features={
+ 'a': [[0]]
+ })
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ FeatureLayer(feature_columns={'a': fc.numeric_column('a')})(
+ features={
+ 'a': [[0]]
+ })
def test_bare_column(self):
with ops.Graph().as_default():
features = {'a': [0.]}
- net = fc.input_layer(features, fc_old.numeric_column('a'))
+ net = FeatureLayer(fc.numeric_column('a'))(features)
with _initialized_session():
self.assertAllClose([[0.]], net.eval())
def test_column_generator(self):
with ops.Graph().as_default():
features = {'a': [0.], 'b': [1.]}
- columns = (fc_old.numeric_column(key) for key in features)
- net = fc.input_layer(features, columns)
+ columns = (fc.numeric_column(key) for key in features)
+ net = FeatureLayer(columns)(features)
with _initialized_session():
self.assertAllClose([[0., 1.]], net.eval())
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ FeatureLayer(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])(
+ features={
+ 'a': [[0]]
+ })
def test_one_column(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())
def test_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
def test_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
- def test_fills_cols_to_vars(self):
- # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
- # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
- # creates a Variable.
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
- with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
- self.assertIsInstance(cols_to_vars[some_embedding_column][0],
- variables_lib.Variable)
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
+ def test_cols_to_output_tensors(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+ cols_dict = {}
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ feature_layer = FeatureLayer([price1, price2])
+ net = feature_layer(features, cols_dict)
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], cols_dict[price1].eval())
+ self.assertAllClose([[3.], [4.]], cols_dict[price2].eval())
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
}
- net1 = fc.input_layer(features, [price_a, price_b])
- net2 = fc.input_layer(features, [price_b, price_a])
+ net1 = FeatureLayer([price_a, price_b])(features)
+ net2 = FeatureLayer([price_b, price_a])(features)
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())
def test_fails_for_categorical_column(self):
- animal = fc_old.categorical_column_with_identity('animal', num_buckets=4)
+ animal = fc.categorical_column_with_identity('animal', num_buckets=4)
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
- fc.input_layer(features, [animal])
+ with self.assertRaisesRegexp(Exception, 'must be a DenseColumn'):
+ FeatureLayer([animal])(features)
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -2926,12 +2864,12 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2])
+ FeatureLayer([price1, price2])(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -2941,31 +2879,31 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2, price3])
+ FeatureLayer([price1, price2, price3])(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'Dimensions of inputs should match'):
sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
sess.run(
net,
@@ -2975,9 +2913,9 @@ class FunctionalInputLayerTest(test.TestCase):
})
def test_multiple_layers_with_same_embedding_column(self):
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
+ some_embedding_column = fc.embedding_column(
some_sparse_column, dimension=10)
with ops.Graph().as_default():
@@ -2985,28 +2923,30 @@ class FunctionalInputLayerTest(test.TestCase):
'sparse_feature': [['a'], ['x']],
}
all_cols = [some_embedding_column]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(all_cols)(features)
+ FeatureLayer(all_cols)(features)
# Make sure that 2 variables get created in this case.
self.assertEqual(2, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
expected_var_names = [
- 'input_layer/sparse_feature_embedding/embedding_weights:0',
- 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ 'feature_layer/sparse_feature_embedding/embedding_weights:0',
+ 'feature_layer_1/sparse_feature_embedding/embedding_weights:0'
]
self.assertItemsEqual(
expected_var_names,
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_multiple_layers_with_same_shared_embedding_column(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
with ops.Graph().as_default():
features = {
@@ -3022,27 +2962,33 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
all_cols = [embedding_column_a, embedding_column_b]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
all_cols = [embedding_column_a, embedding_column_b]
with ops.Graph().as_default():
+ shared_state_manager1 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3055,12 +3001,16 @@ class FunctionalInputLayerTest(test.TestCase):
values=(1, 2, 1),
dense_shape=(2, 2)),
}
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager1)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
with ops.Graph().as_default():
+ shared_state_manager2 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features1 = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3074,12 +3024,14 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
- fc.input_layer(features1, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager2)(
+ features1)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_with_numpy_input_fn(self):
@@ -3092,14 +3044,14 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- # one_hot_body_style has 3 dims in input_layer.
- one_hot_body_style = fc_old.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- embedded_body_style = fc_old.embedding_column(
+ # one_hot_body_style has 3 dims in feature_layer.
+ one_hot_body_style = fc.indicator_column(body_style)
+ # embedded_body_style has 5 dims in feature_layer.
+ embedded_body_style = fc.embedding_column(
body_style, dimension=5, initializer=_initializer)
input_fn = numpy_io.numpy_input_fn(
@@ -3110,8 +3062,8 @@ class FunctionalInputLayerTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_body_style])
+ net = FeatureLayer([price, one_hot_body_style, embedded_body_style])(
+ features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
@@ -3137,18 +3089,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_country has 5 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=5, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3165,8 +3117,7 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
self.assertEqual(1, features['country'].shape.ndims)
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
@@ -3187,18 +3138,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_country has 2 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=2, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3219,8 +3170,7 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2,))
country_data = np.array([['US'], ['CA']])
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 2, net.shape[1])
with _initialized_session() as sess:
@@ -3237,8 +3187,8 @@ class FunctionalInputLayerTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -3246,13 +3196,13 @@ class FunctionalInputLayerTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
@@ -3267,7 +3217,7 @@ class MakeParseExampleSpecTest(test.TestCase):
@property
def name(self):
- return "_TestFeatureColumn"
+ return '_TestFeatureColumn'
def transform_feature(self, transformation_cache, state_manager):
pass
@@ -3593,25 +3543,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_file(
- key='aaa',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',
@@ -4043,24 +3974,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
@@ -4356,22 +4269,6 @@ class IdentityCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 1, 0),
- dense_shape=(2, 2))
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
id_weight_pair = column.get_sparse_tensors(
@@ -4765,16 +4662,16 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
- def test_input_layer(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ def test_feature_layer(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- net = fc.input_layer(features, [animal])
+ net = FeatureLayer([animal])(features)
with _initialized_session():
self.assertAllClose([[0., 1., 1., 0.]], net.eval())
@@ -4786,12 +4683,13 @@ class _TestStateManager(StateManager):
self._all_variables = {}
self._trainable = trainable
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
if feature_column not in self._all_variables:
self._all_variables[feature_column] = {}
var_dict = self._all_variables[feature_column]
@@ -4801,11 +4699,19 @@ class _TestStateManager(StateManager):
var = variable_scope.get_variable(
name=name,
shape=shape,
- initializer=initializer,
- trainable=self._trainable)
+ dtype=dtype,
+ trainable=self._trainable and trainable,
+ initializer=initializer)
var_dict[name] = var
return var
+ def get_variable(self, feature_column, name):
+ if feature_column not in self._all_variables:
+ raise ValueError('Do not recognize FeatureColumn.')
+ if name in self._all_variables[feature_column]:
+ return self._all_variables[feature_column][name]
+ raise ValueError('Could not find variable.')
+
class EmbeddingColumnTest(test.TestCase):
@@ -4967,6 +4873,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5028,6 +4935,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5043,36 +4951,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 5))
-
- # Build columns.
- categorical_column = fc.categorical_column_with_identity(
- key='aaa', num_buckets=3)
- embedding_column = fc.embedding_column(categorical_column, dimension=2)
-
- # Provide sparse input and get dense result.
- embedding_column.get_dense_tensor(
- FeatureTransformationCache({
- 'aaa': sparse_input
- }),
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
- tuple([v.name for v in global_vars]))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in my_vars]))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5117,6 +4995,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
input_indices = array_ops.placeholder(dtype=dtypes.int64)
@@ -5187,6 +5066,7 @@ class EmbeddingColumnTest(test.TestCase):
ckpt_to_load_from=ckpt_path,
tensor_name_in_ckpt=ckpt_tensor)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5354,7 +5234,7 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_input_layer(self):
+ def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5392,30 +5272,29 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ l = FeatureLayer((embedding_column,))
+ feature_layer = l({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
with _initialized_session():
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
- def test_input_layer_not_trainable(self):
+ def test_feature_layer_not_trainable(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5453,65 +5332,26 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer,
trainable=False)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ feature_layer = FeatureLayer((embedding_column,))({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
self.assertItemsEqual(
[], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
with _initialized_session():
self.assertAllEqual(embedding_values, global_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
-
-
-class _TestSharedEmbeddingStateManager(StateManager):
- """Manages the state for shared embedding columns.
-
- This can handle multiple groups of shared embedding columns.
- """
-
- def __init__(self, trainable=True):
- # Dict of shared_embedding_collection_name to a dict of variables.
- self._all_variables = {}
- self._trainable = trainable
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- if not isinstance(feature_column, fc.SharedEmbeddingColumn):
- raise ValueError(
- 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. '
- 'Given type: {} '.format(type(feature_column)))
-
- collection_name = feature_column.shared_collection_name
- if collection_name not in self._all_variables:
- self._all_variables[collection_name] = {}
- var_dict = self._all_variables[collection_name]
- if name in var_dict:
- return var_dict[name]
- else:
- var = variable_scope.get_variable(
- name=name,
- shape=shape,
- initializer=initializer,
- trainable=self._trainable)
- var_dict[name] = var
- return var
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
class SharedEmbeddingColumnTest(test.TestCase):
@@ -5522,7 +5362,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
@@ -5560,7 +5400,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5605,7 +5445,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- original_a, _ = fc.shared_embedding_columns(
+ original_a, _ = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5613,7 +5453,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
shared_embedding_collection_name='shared_embedding_collection_name',
ckpt_to_load_from='my_ckpt',
tensor_name_in_ckpt='my_ckpt_tensor',
- max_norm=42., trainable=False)
+ max_norm=42.,
+ trainable=False)
for embedding_column_a in (original_a, copy.deepcopy(original_a)):
self.assertEqual('aaa', embedding_column_a.categorical_column.name)
self.assertEqual(3, embedding_column_a.categorical_column.num_buckets)
@@ -5642,8 +5483,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
- fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b], dimension=2,
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b],
+ dimension=2,
initializer='not_fn')
def test_incompatible_column_type(self):
@@ -5656,7 +5498,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'all categorical_columns must have the same type.*'
'IdentityCategoricalColumn.*HashedCategoricalColumn'):
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b, categorical_column_c],
dimension=2)
@@ -5669,11 +5511,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='bbb', num_buckets=3)
weighted_categorical_column_b = fc.weighted_categorical_column(
categorical_column_b, weight_feature_key='bbb_weights')
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, weighted_categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, weighted_categorical_column_b],
dimension=2)
@@ -5682,8 +5524,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
b = fc.categorical_column_with_vocabulary_list(
key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
data = example_pb2.Example(features=feature_pb2.Features(
feature={
'aaa':
@@ -5717,8 +5558,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
def test_transform_feature(self):
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
features = {
'aaa': sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -5788,10 +5628,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer')
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -5801,7 +5644,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
+ self.assertItemsEqual(('shared_feature_layer/aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
@@ -5809,58 +5652,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- # Inputs.
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
- input_features = {'aaa': input_a, 'bbb': input_b}
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_values = (
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return embedding_values
-
- # Build columns.
- categorical_column_a = fc.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- fc.input_layer(
- input_features, [embedding_column_a, embedding_column_b],
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in global_vars))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in my_vars))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5903,10 +5694,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager()
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -6096,7 +5890,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
- def _test_input_layer(self, trainable=True):
+ def _test_feature_layer(self, trainable=True):
# Inputs.
vocabulary_size = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -6111,6 +5905,18 @@ class SharedEmbeddingColumnTest(test.TestCase):
indices=((0, 0),),
values=(0,),
dense_shape=(2, 5))
+ sparse_input_c = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 1), (1, 1), (1, 3)),
+ values=(2, 0, 1),
+ dense_shape=(2, 5))
+ sparse_input_d = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids []
+ indices=((0, 1),),
+ values=(2,),
+ dense_shape=(2, 5))
# Embedding variable.
embedding_dimension = 2
@@ -6130,51 +5936,127 @@ class SharedEmbeddingColumnTest(test.TestCase):
# example 0:
# A ids [2], embedding = [7, 11]
# B ids [0], embedding = [1, 2]
- (7., 11., 1., 2.),
+ # C ids [2], embedding = [7, 11]
+ # D ids [2], embedding = [7, 11]
+ (7., 11., 1., 2., 7., 11., 7., 11.),
# example 1:
# A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
# B ids [], embedding = [0, 0]
- (2., 3.5, 0., 0.),
+ # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # D ids [], embedding = [0, 0]
+ (2., 3.5, 0., 0., 2., 3.5, 0., 0.),
)
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=vocabulary_size)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=vocabulary_size)
+
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer,
trainable=trainable)
+ embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d],
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=trainable)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+
+ features = {
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ 'ccc': sparse_input_c,
+ 'ddd': sparse_input_d
+ }
# Provide sparse input and get dense result.
- input_layer = fc.input_layer(
- features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
- feature_columns=(embedding_column_b, embedding_column_a))
+ feature_layer = FeatureLayer(
+ feature_columns=(embedding_column_b, embedding_column_a,
+ embedding_column_c, embedding_column_d),
+ shared_state_manager=shared_state_manager)(
+ features)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
if trainable:
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in trainable_vars]))
else:
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
shared_embedding_vars = global_vars
with _initialized_session():
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
+
-  def test_input_layer(self):
-    self._test_input_layer()
+  def test_feature_layer(self):
+    self._test_feature_layer()

-  def test_input_layer_no_trainable(self):
-    self._test_input_layer(trainable=False)
+  def test_feature_layer_no_trainable(self):
+    self._test_feature_layer(trainable=False)
+
+
+class SharedEmbeddingStateManagerTest(test.TestCase):
+ def test_basic(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_b = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ self.assertEqual(var_a, var_b)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertIsInstance(var_a, variables_lib.Variable)
+
+ def test_multiple_sets(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=3)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=3)
+
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_c = shared_state_manager.create_variable('ccc_ddd_shared_embedding',
+ [5, 10])
+ self.assertIsInstance(var_a, variables_lib.Variable)
+ self.assertIsInstance(var_c, variables_lib.Variable)
+ self.assertNotEquals(var_a, var_c)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertEqual('shared_feature_layer/ccc_ddd_shared_embedding:0',
+ var_c.name)
class WeightedCategoricalColumnTest(test.TestCase):
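Taken together, a hedged sketch of the v2 shared-embedding wiring the tests above exercise (module import assumed as elsewhere in this file; column keys are illustrative):

    from tensorflow.python.feature_column import feature_column_v2 as fc

    col_a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    col_b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
    emb_a, emb_b = fc.shared_embedding_columns_v2([col_a, col_b], dimension=2)
    manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer')
    emb_a.create_state(manager)  # both calls resolve to a single variable:
    emb_b.create_state(manager)  # 'shared_feature_layer/aaa_bbb_shared_embedding:0'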
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f47c0d8a5e..a8aef3a009 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import hashlib
-import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -34,7 +33,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
@@ -42,9 +40,6 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
-# This is to avoid a circular dependency with cond_v2_impl.
-cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-
class Defun(object):
"""Decorator used to define TensorFlow functions.
@@ -1029,20 +1024,10 @@ def _from_definition(fdef, grad_func=None):
result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
python_grad_func, out_names)
# pylint: disable=protected-access
- if ops._USE_C_API:
- serialized = fdef.SerializeToString()
- c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- result._c_func = c_api_util.ScopedTFFunction(c_func)
- result._extra_inputs = []
- else:
- result._definition = fdef
- # Captured inputs are added as regular inputs to a function when it's
- # serialized, i.e. any extra inputs from the original function are now
- # included in `result`._args
- result._extra_inputs = []
- result._hash_str = result._create_hash_str(
- result._definition.signature.input_arg,
- result._definition.signature.output_arg, result._definition.node_def)
+ serialized = fdef.SerializeToString()
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ result._c_func = c_api_util.ScopedTFFunction(c_func)
+ result._extra_inputs = []
# pylint: enable=protected-access
return result
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 1b09506662..a04fa369ae 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -23,7 +23,7 @@ import sys
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
@@ -34,13 +34,13 @@ cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=p
def function_def_to_graph(fdef, input_shapes=None):
- """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+ """Converts a FunctionDef to a function.FuncGraph (sub-class Graph).
- The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set.
+ The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
The input tensors are represented as placeholders.
- Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be
- set by the caller.
+ Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set
+ by the caller.
Args:
fdef: FunctionDef.
@@ -50,9 +50,9 @@ def function_def_to_graph(fdef, input_shapes=None):
placeholder will have unknown shape.
Returns:
- A _FuncGraph.
+ A FuncGraph.
"""
- func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access
+ func_graph = function.FuncGraph(fdef.signature.name)
graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
fdef, input_shapes)
@@ -60,7 +60,7 @@ def function_def_to_graph(fdef, input_shapes=None):
# Add all function nodes to the graph.
importer.import_graph_def(graph_def, name="")
- # Initialize fields specific to _FuncGraph.
+ # Initialize fields specific to FuncGraph.
# inputs
input_tensor_names = [
@@ -144,6 +144,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+ control_name = "^" + arg_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
for node_def in fdef.node_def:
op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
@@ -172,6 +174,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
flat_name = "{}:{}".format(node_def.name, flattened_index)
nested_to_flat_tensor_name[nested_name] = flat_name
flattened_index += 1
+ control_name = "^" + node_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
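The two added blocks extend the name map so control edges ('^name') survive the FunctionDef-to-GraphDef conversion; for an input arg x and a node foo_1 the map now contains entries like the following (mirroring the updated test below):

    nested_to_flat_tensor_name = {
        'x': 'x:0',              # data output of input arg `x`
        '^x': '^x',              # control edge on `x`, now mapped through
        'foo_1:d:0': 'foo_1:0',  # flattened data output
        '^foo_1': '^foo_1',      # control edge on node `foo_1`
    }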
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 21d2c7d990..938814f1d0 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
@@ -154,14 +154,20 @@ class FunctionDefToGraphDefTest(test.TestCase):
self.assertDictEqual(
tensor_name_map, {
"x": "x:0",
+ "^x": "^x",
"y": "y:0",
+ "^y": "^y",
"z": "z:0",
+ "^z": "^z",
"foo_1:d:0": "foo_1:0",
"foo_1:e:0": "foo_1:1",
+ "^foo_1": "^foo_1",
"list_output:a:0": "list_output:0",
"list_output:a:1": "list_output:1",
+ "^list_output": "^list_output",
"foo_2:d:0": "foo_2:0",
"foo_2:e:0": "foo_2:1",
+ "^foo_2": "^foo_2",
})
def testShapes(self):
@@ -184,23 +190,25 @@ class FunctionDefToGraphDefTest(test.TestCase):
x = constant_op.constant(5.0)
y = constant_op.constant(10.0)
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def inner_fn():
return x + y
return inner_fn()
- # Instantiate the function in this graph so that
- # `function_def_to_graph` can find it.
- fn()
-
def fn2():
return 2 * fn()
- fdef = function._DefinedFunction(fn2, [], []).definition
+ fn2_defun = function.make_defun_op(fn2)
+
+ # Call `fn2` to make sure `fn` is correctly instantiated so
+ # `function_def_to_graph` can find it.
+ fn2_defun()
+
+ fdef = fn2_defun._inference_function.definition
func_graph = function_def_to_graph.function_def_to_graph(fdef)
with func_graph.as_default():
x_ph, y_ph = func_graph.inputs
@@ -211,6 +219,25 @@ class FunctionDefToGraphDefTest(test.TestCase):
y_ph: 10.0
}), 30.0)
+ def testControlDependencies(self):
+
+ def fn(inp):
+ x = constant_op.constant(2.0, name="x")
+ # TODO(b/79881896): Test external control dependency once that's
+ # supported.
+ with ops.control_dependencies([x, inp]):
+ constant_op.constant(3.0, name="y")
+ return 4.0
+
+ inp = constant_op.constant(1.0)
+ fdef = function.make_defun_op(fn, inp)._inference_function.definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef)
+
+ op = func_graph.get_operation_by_name("y")
+ self.assertEqual(len(op.control_inputs), 2)
+ self.assertEqual(op.control_inputs[0].name, "x")
+ self.assertEqual(op.control_inputs[1].name, "placeholder")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 5ebe43ff93..8c85a422e7 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import os
import re
import sys
import threading
@@ -67,7 +66,7 @@ from tensorflow.python.util.tf_export import tf_export
# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
-_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0"
+_USE_C_SHAPES = True
def tensor_id(tensor):
@@ -2859,19 +2858,11 @@ class Graph(object):
# TODO(skyewm): fold as much of the above as possible into the C
# implementation
- if self._use_c_api_hack():
- self._scoped_c_graph = c_api_util.ScopedTFGraph()
- # The C API requires all ops to have shape functions. Disable this
- # requirement (many custom ops do not have shape functions, and we don't
- # want to break these existing cases).
- c_api.SetRequireShapeInferenceFns(self._c_graph, False)
- else:
- self._scoped_c_graph = None
-
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- """Temporary hack; can be overridden to force C API usage."""
- return _USE_C_API
+ self._scoped_c_graph = c_api_util.ScopedTFGraph()
+ # The C API requires all ops to have shape functions. Disable this
+ # requirement (many custom ops do not have shape functions, and we don't
+ # want to break these existing cases).
+ c_api.SetRequireShapeInferenceFns(self._c_graph, False)
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@@ -3121,7 +3112,7 @@ class Graph(object):
Returns:
bool indicating whether or not 'name' is registered in function library.
"""
- return name in self._functions
+ return compat.as_str(name) in self._functions
def _get_function(self, name):
"""Returns the function definition for 'name'.
@@ -3131,7 +3122,7 @@ class Graph(object):
Returns:
The function def proto.
"""
- return self._functions.get(name, None)
+ return self._functions.get(compat.as_str(name), None)
def _add_function(self, function):
"""Adds a function to the graph.
@@ -3167,7 +3158,7 @@ class Graph(object):
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
# pylint: enable=protected-access
- self._functions[name] = function
+ self._functions[compat.as_str(name)] = function
# Need a new-enough consumer to support the functions we add to the graph.
if self._graph_def_versions.min_consumer < 12:
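Function names may arrive as bytes (e.g. from the C API) or as str; normalizing the dict keys with compat.as_str makes both forms address the same entry. A minimal sketch:

    from tensorflow.python.util import compat

    functions = {}
    functions[compat.as_str(b'my_fn')] = 'fdef'   # stored under key 'my_fn'
    assert compat.as_str('my_fn') in functions    # str lookup matches
    assert compat.as_str(b'my_fn') in functions   # bytes lookup matches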
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
index 48a834392b..7ee2b5b347 100644
--- a/tensorflow/python/framework/smart_cond.py
+++ b/tensorflow/python/framework/smart_cond.py
@@ -77,11 +77,9 @@ def smart_constant_value(pred):
pred_value = pred
elif isinstance(pred, ops.Tensor):
pred_value = tensor_util.constant_value(pred)
- # TODO(skyewm): consider folding this into tensor_util.constant_value when
- # _USE_C_API is removed (there may be performance and correctness bugs, so I
- # wanted to limit the change hidden behind _USE_C_API).
+ # TODO(skyewm): consider folding this into tensor_util.constant_value.
# pylint: disable=protected-access
- if pred_value is None and ops._USE_C_API:
+ if pred_value is None:
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
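With the _USE_C_API guard removed, smart_constant_value always falls back to the C-API evaluator when tensor_util.constant_value returns None. A small sketch of the contract; that the fold succeeds for this trivial graph is an assumption:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import smart_cond

    print(smart_cond.smart_constant_value(True))   # plain bools pass through
    pred = constant_op.constant(True)
    print(smart_cond.smart_constant_value(pred))   # statically folded to True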
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py
index cee7398974..00759eb611 100644
--- a/tensorflow/python/framework/subscribe.py
+++ b/tensorflow/python/framework/subscribe.py
@@ -137,12 +137,7 @@ def _subscribe_new(tensor, side_effects, control_cache):
# are subscribed at the same time, we remove the control dependency from
# the original op only once and we add the dependencies to all the
# new identities.
- if ops._USE_C_API: # pylint: disable=protected-access
- new_control_inputs = consumer_op.control_inputs
- else:
- # Make a copy so we don't modify the actual control inputs (this is fixed
- # in the C API).
- new_control_inputs = list(consumer_op.control_inputs)
+ new_control_inputs = consumer_op.control_inputs
if tensor.op in new_control_inputs:
new_control_inputs.remove(tensor.op)
new_control_inputs.append(out.op)
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ec0daeaddb..266af56611 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1409,8 +1409,8 @@ class TestCTC(test.TestCase):
np.array([seq_len_0], dtype=np.int32))
# batch_size length vector of negative log probabilities
log_prob_truth = np.array([
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -3.5821197, # output beam 0
+ -3.777835 # output beam 1
], np.float32)[np.newaxis, :]
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index cf6fb44275..9f4019e29c 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -332,6 +332,7 @@ class Sequential(Model):
else:
name = None
build_input_shape = None
+ layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
layer = layer_module.deserialize(layer_config,
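The added `layer_configs = config` line keeps the legacy serialization format working: old configs are a bare list of layer configs, new ones a dict with 'name' and 'layers'. An illustrative sketch (config contents hypothetical):

    dense_cfg = {'class_name': 'Dense', 'config': {'name': 'd1', 'units': 4}}
    new_style = {'name': 'seq', 'layers': [dense_cfg]}  # unpacked above
    legacy_style = [dense_cfg]  # hits the else branch; used as-is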
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 65171acfb6..cff612a8de 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -73,19 +73,27 @@ class StackedRNNCells(Layer):
'`state_size` attribute. '
'received cells:', cells)
self.cells = cells
+    # reverse_state_order determines whether the state size will be in the
+    # reverse order of the cells' states. Users may want to set this to True
+    # to keep the existing behavior. It only matters when using
+    # RNN(return_state=True), since states are returned in state_size order.
+ self.reverse_state_order = kwargs.pop('reverse_state_order', False)
+ if self.reverse_state_order:
+ logging.warning('reverse_state_order=True in StackedRNNCells will soon '
+ 'be deprecated. Please update the code to work with the '
+                      'natural order of states if you rely on the RNN states, '
+                      'e.g. RNN(return_state=True).')
super(StackedRNNCells, self).__init__(**kwargs)
@property
def state_size(self):
- # States are a flat list
- # in reverse order of the cell stack.
- # This allows to preserve the requirement
- # `stack.state_size[0] == output_dim`.
- # e.g. states of a 2-layer LSTM would be
- # `[h2, c2, h1, c1]`
+ # States are a flat list of the individual cell state size.
+ # e.g. states of a 2-layer LSTM would be `[h1, c1, h2, c2]`.
# (assuming one LSTM has states [h, c])
+ # In the case of reverse_state_order=True, the state_size will be
+ # [h2, c2, h1, c1].
state_size = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
state_size += list(cell.state_size)
else:
@@ -96,15 +104,16 @@ class StackedRNNCells(Layer):
def output_size(self):
if getattr(self.cells[-1], 'output_size', None) is not None:
return self.cells[-1].output_size
+ elif _is_multiple_state(self.cells[-1].state_size):
+ return self.cells[-1].state_size[0]
else:
- return self.state_size[0]
+ return self.cells[-1].state_size
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
- # The init state is in reverse order of cell's initial state since the
- # state_size is in reverse order. It is flattened into a list also because
- # the state_size is a flattened list.
+ # The init state is flattened into a list because state_size is a flattened
+ # list.
initial_states = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
get_initial_state_fn = getattr(cell, 'get_initial_state', None)
if get_initial_state_fn:
initial_states.append(get_initial_state_fn(
@@ -118,14 +127,15 @@ class StackedRNNCells(Layer):
def call(self, inputs, states, constants=None, **kwargs):
# Recover per-cell states.
nested_states = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
nested_states.append(states[:len(cell.state_size)])
states = states[len(cell.state_size):]
else:
nested_states.append([states[0]])
states = states[1:]
- nested_states = nested_states[::-1]
+ if self.reverse_state_order:
+ nested_states = nested_states[::-1]
# Call the cells in order and store the returned states.
new_nested_states = []
@@ -139,11 +149,12 @@ class StackedRNNCells(Layer):
new_nested_states.append(states)
# Format the new states as a flat list
- # in reverse cell order.
- states = []
- for cell_states in new_nested_states[::-1]:
- states += cell_states
- return inputs, states
+ new_states = []
+ if self.reverse_state_order:
+ new_nested_states = new_nested_states[::-1]
+ for cell_states in new_nested_states:
+ new_states += cell_states
+ return inputs, new_states
@tf_utils.shape_type_conversion
def build(self, input_shape):
@@ -156,7 +167,9 @@ class StackedRNNCells(Layer):
cell.build([input_shape] + constants_shape)
else:
cell.build(input_shape)
- if _is_multiple_state(cell.state_size):
+ if getattr(cell, 'output_size', None) is not None:
+ output_dim = cell.output_size
+ elif _is_multiple_state(cell.state_size):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
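In short, the flat state order flips from reversed to natural, with reverse_state_order=True as the compatibility escape hatch; a sketch for a two-layer LSTM stack (assuming keras LSTMCell states [h, c]):

    from tensorflow.python import keras

    cells = [keras.layers.LSTMCell(8), keras.layers.LSTMCell(32)]
    stacked = keras.layers.StackedRNNCells(cells)
    # New default: natural order [h1, c1, h2, c2] -> sizes (8, 8, 32, 32)
    legacy = keras.layers.StackedRNNCells(cells, reverse_state_order=True)
    # Old behavior: reversed order [h2, c2, h1, c1] -> sizes (32, 32, 8, 8)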
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index b52bfc05a5..a3861e44d5 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -103,7 +103,8 @@ class RNNTest(test.TestCase):
MinimalRNNCell(16, 8),
MinimalRNNCell(32, 16)]
layer = keras.layers.RNN(cells)
- assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
+ self.assertEqual(layer.cell.state_size, (8, 8, 16, 16, 32, 32))
+ self.assertEqual(layer.cell.output_size, 32)
y = layer(x)
model = keras.models.Model(x, y)
model.compile(optimizer='rmsprop', loss='mse')
@@ -551,6 +552,21 @@ class RNNTest(test.TestCase):
layer = keras.layers.RNN(cells, return_state=True, return_sequences=True)
output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
expected_output_shape = [(None, timesteps, 6),
+ (None, 3),
+ (None, 3),
+ (None, 6),
+ (None, 6)]
+ self.assertEqual(
+ [tuple(o.as_list()) for o in output_shape],
+ expected_output_shape)
+
+ # Test reverse_state_order = True for stacked cell.
+ stacked_cell = keras.layers.StackedRNNCells(
+ cells, reverse_state_order=True)
+ layer = keras.layers.RNN(
+ stacked_cell, return_state=True, return_sequences=True)
+ output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
+ expected_output_shape = [(None, timesteps, 6),
(None, 6),
(None, 6),
(None, 3),
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 6bc256d2ec..39b6042597 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -33,6 +33,7 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.util.tf_export import tf_export
# API entries importable from `keras.models`:
Model = training.Model # pylint: disable=invalid-name
@@ -226,6 +227,7 @@ def _clone_sequential_model(model, input_tensors=None):
return Sequential(layers=[input_layer] + layers, name=model.name)
+@tf_export('keras.models.clone_model')
def clone_model(model, input_tensors=None):
"""Clone any `Model` instance.
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index b9910133d8..0dc3c53bc0 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -20,9 +20,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
@@ -158,7 +158,7 @@ class CondV2Test(test.TestCase):
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
return x * y * 2.0
@@ -172,6 +172,8 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -180,10 +182,10 @@ class CondV2Test(test.TestCase):
def false_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
return x * y * 2.0
@@ -196,18 +198,20 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
- @function.Defun()
+ @function.defun
def nested_nested_fn():
return x * y * 2.0
@@ -368,7 +372,7 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -426,10 +430,10 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
- @function.Defun()
+ @function.defun
def inner_nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -464,6 +468,7 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
+ self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -472,7 +477,7 @@ class CondV2Test(test.TestCase):
y = constant_op.constant(2.0, name="y")
# Build cond and its gradient inside a Defun.
- @function.Defun()
+ @function.defun
def fn():
def true_fn():
@@ -718,6 +723,7 @@ class CondV2ContainerTest(test.TestCase):
Make sure the containers are set correctly for both variable creation
(tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
"""
+ self.skipTest("b/113048653")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -795,6 +801,7 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -819,6 +826,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
def testColocateWithInAndOutOfCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -866,6 +874,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
+ self.skipTest("b/112166045")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
def fn():
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 5e0447e4ff..4a3e767f4d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -32,6 +32,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as _ # pylint: disable=unused-import
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
diff --git a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
index e1920eb568..41ae0b456f 100644
--- a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
+++ b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
@@ -188,11 +188,11 @@ class CTCGreedyDecoderTest(test.TestCase):
],
dtype=np.float32)
# Add arbitrary offset - this is fine
- input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0
+ input_prob_matrix_0 = input_prob_matrix_0 + 2.0
# len max_time_steps array of batch_size x depth matrices
inputs = ([
- input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
+ input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
] # Pad to max_time_steps = 8
+ 2 * [np.zeros(
(1, depth), dtype=np.float32)])
@@ -200,11 +200,11 @@ class CTCGreedyDecoderTest(test.TestCase):
# batch_size length vector of sequence_lengths
seq_lens = np.array([seq_len_0], dtype=np.int32)
- # batch_size length vector of negative log probabilities
+ # batch_size length vector of log probabilities
log_prob_truth = np.array(
[
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -5.811451, # output beam 0
+ -6.63339 # output beam 1
],
np.float32)[np.newaxis, :]
@@ -215,11 +215,11 @@ class CTCGreedyDecoderTest(test.TestCase):
[[0, 0], [0, 1]], dtype=np.int64), np.array(
[1, 0], dtype=np.int64), np.array(
[1, 2], dtype=np.int64)),
- # beam 1, batch 0, three outputs decoded
+ # beam 1, batch 0, one output decoded
(np.array(
- [[0, 0], [0, 1], [0, 2]], dtype=np.int64), np.array(
- [0, 1, 0], dtype=np.int64), np.array(
- [1, 3], dtype=np.int64)),
+ [[0, 0]], dtype=np.int64), np.array(
+ [1], dtype=np.int64), np.array(
+ [1, 1], dtype=np.int64)),
]
# Test correct decoding.
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 9eaafb4435..b167278984 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -142,7 +142,7 @@ class MatMulStatsTest(test_lib.TestCase):
for op in g.get_operations():
flops = ops.get_stats_for_node_def(g, op.node_def, "flops").value
if op.name == "MatMul":
- self.assertEqual(6975, flops)
+ self.assertEqual(7200, flops)
def testTransposedStatistics(self):
g = ops.Graph()
@@ -153,7 +153,7 @@ class MatMulStatsTest(test_lib.TestCase):
for op in g.get_operations():
flops = ops.get_stats_for_node_def(g, op.node_def, "flops").value
if op.name == "MatMul":
- self.assertEqual(6975, flops)
+ self.assertEqual(7200, flops)
try:
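The new expected values follow from counting 2k ops per output element instead of 2k-1, assuming the test multiplies a [25, 16] matrix by a [16, 9] one (the shapes are not visible in these hunks):

    m, k, n = 25, 16, 9
    old_flops = m * n * (2 * k - 1)   # 25 * 9 * 31 = 6975
    new_flops = m * n * (2 * k)       # 25 * 9 * 32 = 7200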
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 1d0c2dceba..15d5702252 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -27,15 +27,12 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib
@@ -549,6 +546,32 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
+ def testControlDepsNone(self):
+ with self.test_session() as session:
+ c = constant_op.constant(1.0)
+ with ops.control_dependencies([c]):
+      # d gets the control dependency.
+ d = constant_op.constant(2.0)
+ # Partitioned variables do not.
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ ops_before_read = session.graph.get_operations()
+ var_x.as_tensor() # Caches the ops for subsequent reads.
+ reading_ops = [
+ op for op in session.graph.get_operations()
+ if op not in ops_before_read
+ ]
+
+ self.assertEqual([c.op], d.op.control_inputs)
+      # Tests that no control dependencies are added when reading a
+      # partitioned variable, matching the behavior of a regular variable.
+ for op in reading_ops:
+ self.assertEqual([], op.control_inputs)
+
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@@ -574,57 +597,6 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
- def testVariableCreationInALoop(self):
- """Tests the variable created inside a loop can be used outside the loop."""
- with self.test_session():
- with variable_scope.variable_scope("ascope") as scope:
- def Body(i, _):
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(
- 4))
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
- variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 1.0], x.eval())
-
- scope.reuse_variables()
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
-
- def testReadInWhileLoop(self):
- """Tests the value is current (not cached) when read within a loop."""
- with self.test_session():
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- def Body(i, _):
- # Use a SGD step to update the variable's value.
- loss = math_ops.reduce_sum(var_x)
- optimizer = gradient_descent.GradientDescentOptimizer(1.0)
- minimize = optimizer.minimize(loss * 0.7)
- with ops.control_dependencies([minimize]):
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
- variables.global_variables_initializer().run()
- self.assertAllClose([-0.4, -0.4], x.eval())
-
def testMetaGraphSaveLoad(self):
save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_graph = ops.Graph()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index c4f200a22e..78f2993d27 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -441,11 +441,11 @@ class RNNTest(test.TestCase):
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(len(state), 4)
- self.assertEqual(state[0].shape.as_list(), [None, output_shape])
- self.assertEqual(state[1].shape.as_list(), [None, output_shape])
- self.assertEqual(state[2].shape.as_list(), [None, 2 * output_shape])
- self.assertEqual(state[3].shape.as_list(), [None, 2 * output_shape])
- loss = losses.softmax_cross_entropy(predict, state[0])
+ self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[2].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[3].shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[2])
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 76173e0f30..75a1a53eb7 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -24,7 +24,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.ops import gradients_impl
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index b3dacff6d6..c4e9c982b5 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -27,14 +27,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.util import compat
# The following modules cannot be imported directly because they cause circular
@@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
name = "cond"
with ops.name_scope(name) as scope:
- # Identify if there is a caller device, & get the innermost if possible.
- # pylint: disable=protected-access
- device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
- caller_device = device_funcs[-1] if device_funcs else None
-
- caller_colocation_stack = ops.get_default_graph()._colocation_stack
- caller_container = ops.get_default_graph()._container
- caller_collection_ref = ops.get_default_graph()._collections
-
with ops.name_scope(None):
# Find the outer most graph for uniquing function names.
# TODO(jpienaar): Make this work in eager mode.
graph = ops.get_default_graph()
- while isinstance(graph, _function._FuncGraph):
- graph = graph._outer_graph
+ while isinstance(graph, _function.FuncGraph):
+ graph = graph.outer_graph
true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))
- # pylint: enable=protected-access
+
true_graph = _function.func_graph_from_py_func(
- true_fn, [], [],
- name=true_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ true_name, true_fn, [], {})
false_graph = _function.func_graph_from_py_func(
- false_fn, [], [],
- name=false_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ false_name, false_fn, [], {})
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
# this modifies true_graph and false_graph.
cond_inputs = _make_inputs_match(true_graph, false_graph,
- true_graph.extra_inputs,
- false_graph.extra_inputs)
+ true_graph.external_captures,
+ false_graph.external_captures)
# Add all intermediate tensors as function outputs so they're available for
# the gradient computation.
@@ -148,8 +128,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
true_graph, false_graph = _get_func_graphs(op)
# Note: op.graph != ops.get_default_graph() when we are computing the gradient
# of a nested cond.
- assert true_graph._outer_graph == op.graph
- assert false_graph._outer_graph == op.graph
+ assert true_graph.outer_graph == op.graph
+ assert false_graph.outer_graph == op.graph
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
@@ -164,14 +144,13 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Resolve references to forward graph tensors in grad graphs and ensure
# they are in-scope, i.e., belong to one of outer graphs of the grad graph.
- true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
- false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
+ true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+ false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# Make the inputs to true_grad_graph and false_grad_graph match. Note that
# this modifies true_grad_graph and false_grad_graph.
grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_extra_inputs,
- false_grad_extra_inputs)
+ true_grad_inputs, false_grad_inputs)
# Add all intermediate tensors as function outputs so they're available for
# higher-order gradient computations.
@@ -211,8 +190,8 @@ def _get_func_graphs(if_op):
"""
def _get_func_graph_for_branch(branch_name):
"""Generates and returns a _FuncGraph for the given branch."""
- extra_inputs = if_op.inputs[1:] # First input is pred.
- input_shapes = [t.shape for t in extra_inputs]
+ inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name
fdef = if_op.graph._get_function(func_name).definition
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
@@ -224,9 +203,8 @@ def _get_func_graphs(if_op):
with if_op.graph.as_default():
func_graph = _function_def_to_graph.function_def_to_graph(
fdef, input_shapes)
- func_graph.extra_inputs = extra_inputs
- func_graph.extra_args = func_graph.inputs
- func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ func_graph.captures = collections.OrderedDict(zip(inputs,
+ func_graph.inputs))
# Set the if op so that the gradient code can use it.
func_graph._if = if_op
return func_graph
@@ -282,12 +260,12 @@ def _grad_fn(func_graph, grads):
def _create_grad_func(func_graph, grads, name):
"""Returns the _FuncGraph representation of _grad_fn."""
- return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
- [], [], name)
+ return _function.func_graph_from_py_func(
+ name, lambda: _grad_fn(func_graph, grads), [], {})
def _resolve_grad_inputs(cond_graph, grad_graph):
- """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
+ """Returns the tensors to pass as inputs to `grad_graph`.
The `grad_graph` may have external references to
1. Its outer graph containing the input gradients. These references are kept
@@ -305,10 +283,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
Returns:
A list of input tensors to be passed to grad_graph.
"""
- new_extra_inputs = []
+ new_inputs = []
- for t in grad_graph.extra_inputs:
- if t.graph != grad_graph._outer_graph:
+ for t in grad_graph.external_captures:
+ if t.graph != grad_graph.outer_graph:
# `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
# tensor to the least common ancestor of the `cond_graph` and
# `grad_graph` so that it is "in-scope" for `grad_graph`.
@@ -316,19 +294,19 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
# common ancestor once and re-use.
assert _is_ancestor(cond_graph, t.graph)
while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function._FuncGraph)
- if t in t.graph.extra_args:
- # TODO(srbs): Consider building a map of extra_args -> extra_inputs.
- # instead of searching for `t` twice.
- t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+ assert isinstance(t.graph, _function.FuncGraph)
+ if t in t.graph.internal_captures:
+ # TODO(srbs): Consider building a map of internal_captures ->
+ # external_captures instead of searching for `t` twice.
+ t = t.graph.external_captures[t.graph.internal_captures.index(t)]
else:
# Note: All intermediate tensors are output by the If op.
# TODO(srbs): .index() calls may be expensive. Optimize.
t = t.graph._if.outputs[t.graph.outputs.index(t)]
assert _is_ancestor(grad_graph, t.graph)
- new_extra_inputs.append(t)
+ new_inputs.append(t)
- return new_extra_inputs
+ return new_inputs
def _create_new_tf_function(func_graph):
@@ -340,26 +318,9 @@ def _create_new_tf_function(func_graph):
Returns:
The name of the new TF_Function.
"""
- c_func = c_api.TF_GraphToFunction_wrapper(
- func_graph._c_graph,
- compat.as_str(func_graph.name),
- False, # append_hash_to_fn_name
- None, # opers
- [t._as_tf_output() for t in func_graph.inputs],
- [t._as_tf_output() for t in func_graph.outputs],
- [],
- None, # opts
- None) # description
- _ = c_api_util.ScopedTFFunction(c_func)
-
- # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
- # deserializing it into a Python FunctionDef, then reserializing it to create
- # a new TF_Function that we add to the graph.
- fdef = _function.function_def_from_tf_function(c_func)
- defined_func = _function._from_definition(fdef)
- defined_func._sub_functions = func_graph._functions
- defined_func.add_to_graph(func_graph._outer_graph)
-
+ func = _function._EagerDefinedFunction(
+ func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
+ func.add_to_graph(func_graph.outer_graph)
return func_graph.name
@@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
return new_true_params, new_false_inputs
-def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
- false_extra_inputs):
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
"""Modifies true_graph and false_graph so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so
- they have the same input signature, and updates the 'inputs', 'extra_inputs',
- and '_captured' fields of both graphs accordingly. It uses the input tensors
- from the outer graph to avoid duplicating shared arguments.
+ they have the same input signature, and updates the 'inputs' and 'captures'
+ fields of both graphs accordingly. It uses the input tensors from the outer
+ graph to avoid duplicating shared arguments.
Args:
true_graph: function._FuncGraph
false_graph: function._FuncGraph
- true_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ true_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
- false_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ false_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns:
@@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
false_inputs.
"""
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
- true_extra_inputs, false_extra_inputs)
+ true_inputs, false_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs
- true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
- false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
+ true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
@@ -462,14 +422,10 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
[false_input_to_param[t] for t in false_only_inputs])
# Rewrite the _FuncGraphs' state to reflect the new inputs.
- true_graph.extra_inputs = new_inputs
- false_graph.extra_inputs = new_inputs
-
- true_graph.extra_args = true_graph.inputs
- false_graph.extra_args = false_graph.inputs
-
- true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
- false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+ true_graph.captures = collections.OrderedDict(zip(new_inputs,
+ true_graph.inputs))
+ false_graph.captures = collections.OrderedDict(zip(new_inputs,
+ false_graph.inputs))
return new_inputs
@@ -506,10 +462,10 @@ def _get_grad_fn_name(func_graph):
counter = 1
has_conflict = True
while has_conflict:
- curr_graph = func_graph._outer_graph
+ curr_graph = func_graph.outer_graph
has_conflict = curr_graph._is_function(name)
- while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
- curr_graph = curr_graph._outer_graph
+ while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
+ curr_graph = curr_graph.outer_graph
has_conflict = curr_graph._is_function(name)
if has_conflict:
name = "%s_%s" % (base_name, counter)
@@ -534,6 +490,6 @@ def _check_same_outputs(true_graph, false_graph):
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True
- if isinstance(graph, _function._FuncGraph):
- return _is_ancestor(graph._outer_graph, maybe_ancestor)
+ if isinstance(graph, _function.FuncGraph):
+ return _is_ancestor(graph.outer_graph, maybe_ancestor)
return False
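The refactor above replaces the old `extra_inputs`/`extra_args`/`_captured` triple with `FuncGraph`'s paired `external_captures`/`internal_captures` lists. Below is a minimal, self-contained sketch (plain Python, not the TF API) of how that pairing supports the bubbling loop in `_resolve_grad_inputs`; all names here are illustrative stand-ins:

```python
class FakeGraph(object):
  """Stand-in for FuncGraph: tracks paired capture lists."""

  def __init__(self, outer_graph=None):
    self.outer_graph = outer_graph
    self.external_captures = []   # Tensors living in the outer graph.
    self.internal_captures = []   # Matching placeholders in this graph.

  def capture(self, outer_tensor):
    placeholder = ("placeholder_for", outer_tensor)
    self.external_captures.append(outer_tensor)
    self.internal_captures.append(placeholder)
    return placeholder


def bubble_out(t, graph):
  """Map an internal capture back to the tensor it captured, one level out."""
  if t in graph.internal_captures:
    return graph.external_captures[graph.internal_captures.index(t)]
  return t


outer = FakeGraph()
inner = FakeGraph(outer_graph=outer)
x = "x_in_outer"
x_inner = inner.capture(x)
assert bubble_out(x_inner, inner) == x
```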
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index d1095c8954..e3c1aa3d5a 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1966,8 +1966,12 @@ def cond(pred,
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
`false_fn` must have the same non-zero number and type of outputs.
- Note that the conditional execution applies only to the operations defined in
- `true_fn` and `false_fn`. Consider the following simple program:
+ **WARNING**: Any Tensors or Operations created outside of `true_fn` and
+ `false_fn` will be executed regardless of which branch is selected at runtime.
+
+ Although this behavior is consistent with the dataflow model of TensorFlow,
+ it has frequently surprised users who expected a lazier semantics.
+ Consider the following simple program:
```python
z = tf.multiply(a, b)
@@ -1978,8 +1982,6 @@ def cond(pred,
operation will not be executed. Since `z` is needed for at least one
branch of the `cond`, the `tf.multiply` operation is always executed,
unconditionally.
- Although this behavior is consistent with the dataflow model of TensorFlow,
- it has occasionally surprised some users who expected a lazier semantics.
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
call to `cond`, and not at all during `Session.run()`). `cond`
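A runnable sketch of the behavior this warning documents, assuming the TF 1.x graph-mode API; the placeholders are illustrative:

```python
import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
pred = tf.placeholder(tf.bool)

z = tf.multiply(a, b)  # Created outside the branches: runs unconditionally.
result = tf.cond(pred,
                 lambda: tf.add(z, 1.0),   # true_fn
                 lambda: tf.square(z))     # false_fn
# Only the add OR the square executes at runtime, but the multiply producing
# `z` runs for either branch, because `z` is defined outside the cond.
```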
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 70b5e9b4b7..9b0ab00c7a 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -618,7 +618,7 @@ def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
The operation casts `x` (in case of `Tensor`) or `x.values`
- (in case of `SparseTensor`) to `dtype`.
+ (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.
For example:
@@ -637,15 +637,16 @@ def cast(x, dtype, name=None):
behavior of numpy.
Args:
- x: A `Tensor` or `SparseTensor` of numeric type. It could be
- `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
- `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
- dtype: The destination type. The list of supported dtypes is the same
- as `x`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
+ be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
+ `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
+ `bfloat16`.
+ dtype: The destination type. The list of supported dtypes is the same as
+ `x`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` and
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
same type as `dtype`.
Raises:
@@ -659,6 +660,9 @@ def cast(x, dtype, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
+ elif isinstance(x, ops.IndexedSlices):
+ values_cast = cast(x.values, base_type, name=name)
+ x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
else:
# TODO(josh11b): If x is not already a Tensor, we could return
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
@@ -711,11 +715,12 @@ def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float32`.
Raises:
TypeError: If `x` cannot be cast to the `float32`.
@@ -728,11 +733,12 @@ def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float64`.
Raises:
TypeError: If `x` cannot be cast to the `float64`.
@@ -745,11 +751,12 @@ def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int32`.
Raises:
TypeError: If `x` cannot be cast to the `int32`.
@@ -762,11 +769,12 @@ def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int64`.
Raises:
TypeError: If `x` cannot be cast to the `int64`.
@@ -779,11 +787,12 @@ def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `bfloat16`.
Raises:
TypeError: If `x` cannot be cast to the `bfloat16`.
@@ -796,11 +805,12 @@ def to_complex64(x, name="ToComplex64"):
"""Casts a tensor to type `complex64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex64`.
Raises:
TypeError: If `x` cannot be cast to the `complex64`.
@@ -813,11 +823,12 @@ def to_complex128(x, name="ToComplex128"):
"""Casts a tensor to type `complex128`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex128`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex128`.
Raises:
TypeError: If `x` cannot be cast to the `complex128`.
@@ -2061,7 +2072,7 @@ def _calc_mat_mul_flops(graph, node):
output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
output_shape.assert_is_fully_defined()
output_count = np.prod(output_shape.as_list())
- return ops.OpStats("flops", ((2 * k - 1) * output_count))
+ return ops.OpStats("flops", (k * output_count * 2))
def _as_indexed_slices(x, optimize=True):
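Two notes on this hunk: the MatMul flops estimate now counts k multiplies plus k adds per output element (2k, versus the previous 2k-1), and `tf.cast` gains an `IndexedSlices` path. A usage sketch of the latter, assuming the TF 1.x API:

```python
import tensorflow as tf

slices = tf.IndexedSlices(
    values=tf.constant([[1, 2], [3, 4]], dtype=tf.int32),
    indices=tf.constant([0, 2], dtype=tf.int64),
    dense_shape=tf.constant([4, 2], dtype=tf.int64))

# cast rebuilds the IndexedSlices with `values` cast to the target dtype,
# leaving `indices` and `dense_shape` untouched.
as_float = tf.cast(slices, tf.float32)
assert isinstance(as_float, tf.IndexedSlices)
```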
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index baba5d4093..4800352ac2 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -355,6 +355,15 @@ class ResourceVariable(variables.RefVariable):
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
+ if isinstance(initial_value, ops.Tensor) and hasattr(
+ initial_value, "graph") and initial_value.graph.building_function:
+ raise ValueError("Tensor-typed variable initializers must either be "
+ "wrapped in an init_scope or callable "
+ "(e.g., `tf.Variable(lambda : "
+ "tf.truncated_normal([10, 40]))`) when building "
+ "functions. Please file a feature request if this "
+ "restriction inconveniences you.")
+
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
if not isinstance(collections, (list, tuple, set)):
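A sketch of the two initializer forms the new check distinguishes; the shape is illustrative and the error text above shows the recommended fix:

```python
import tensorflow as tf

# While a graph function is being built, a Tensor-typed initializer is now
# rejected:
#   v = tf.Variable(tf.truncated_normal([10, 40]))   # raises ValueError
#
# A callable initializer defers tensor creation until initialization time
# and is accepted:
v = tf.Variable(lambda: tf.truncated_normal([10, 40]))
```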
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index d990386b9a..38ce5236e3 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -96,6 +96,60 @@ def _make_int64_tensor(value, name):
return math_ops.cast(value, dtypes.int64)
+@tf_export("sparse.expand_dims")
+def sparse_expand_dims(sp_input, axis=None, name=None):
+ """Inserts a dimension of 1 into a tensor's shape.
+
+ Given a tensor `sp_input`, this operation inserts a dimension of 1 at the
+ dimension index `axis` of `sp_input`'s shape. The dimension index `axis`
+ starts at zero; if you specify a negative number for `axis` it is counted
+ backwards from the end.
+
+ Args:
+ sp_input: A `SparseTensor`.
+ axis: 0-D (scalar). Specifies the dimension index at which to expand the
+ shape of `sp_input`. Must be in the range `[-rank(sp_input) - 1,
+ rank(sp_input)]`.
+ name: The name of the output `SparseTensor`.
+
+ Returns:
+ A `SparseTensor` with the same data as `sp_input`, but its shape has an
+ additional dimension of size 1 added.
+ """
+ rank = sp_input.dense_shape.get_shape()[0]
+ axis = -1 if axis is None else axis
+
+ with ops.name_scope(name, default_name="expand_dims", values=[sp_input]):
+ if isinstance(axis, compat.integral_types):
+ axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32)
+ elif not isinstance(axis, ops.Tensor):
+ raise TypeError("axis must be an integer value in range [-rank(sp_input)"
+ " - 1, rank(sp_input)]")
+
+ # Convert axis to a positive value if it is negative.
+ axis = array_ops.where(axis >= 0, axis, axis + rank + 1)
+
+ # Create the new column of indices for the sparse tensor by slicing
+ # the indices and inserting a new column of indices for the new dimension.
+ column_size = array_ops.shape(sp_input.indices)[0]
+ new_index = array_ops.zeros([column_size, 1], dtype=dtypes.int64)
+ indices_before = array_ops.slice(sp_input.indices, [0, 0], [-1, axis])
+ indices_after = array_ops.slice(sp_input.indices, [0, axis], [-1, -1])
+ indices = array_ops.concat(
+ [indices_before, new_index, indices_after], axis=1)
+
+ # Create the new dense shape by splicing the tensor [1] in the correct
+ # dimension of the existing shape.
+ shape_before = array_ops.slice(sp_input.dense_shape, [0], [axis])
+ shape_after = array_ops.slice(sp_input.dense_shape, [axis], [-1])
+ new_shape = ops.convert_to_tensor([1], name="new_shape", dtype=dtypes.int64)
+ shape = array_ops.concat([shape_before, new_shape, shape_after], axis=0)
+
+ # Create the output sparse tensor.
+ return sparse_tensor.SparseTensor(
+ indices=indices, values=sp_input.values, dense_shape=shape)
+
+
@tf_export("sparse.eye")
def sparse_eye(num_rows,
num_columns=None,
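A usage sketch for `tf.sparse.expand_dims` as exported by this change, mirroring the docstring's semantics (negative `axis` counts from the end, and `axis=None` defaults to -1):

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                     values=[1.0, 2.0],
                     dense_shape=[2, 3])       # dense shape [2, 3]

sp_lead = tf.sparse.expand_dims(sp, axis=0)    # dense shape [1, 2, 3]
sp_tail = tf.sparse.expand_dims(sp)            # axis=None -> -1: [2, 3, 1]
```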
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index b10c3c2187..4ee1569249 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@@ -45,5 +47,35 @@ class SparseOpsTest(test_util.TensorFlowTestCase):
test_one(n, m, True)
test_one(n, m, False)
+ def testSparseExpandDims(self):
+ for rank in range(1, 4):
+ # Create a dummy input. When rank=3, shape=[2, 4, 6].
+ shape = np.arange(1, rank + 1) * 2
+ before = np.arange(np.prod(shape)).reshape(shape)
+
+ # Make entries sparse.
+ before *= np.random.binomial(1, .2, before.shape)
+ dense_shape = before.shape
+ indices = np.array(np.where(before)).T
+ values = before[before != 0]
+
+ # Try every possible valid value of axis.
+ for axis in range(-rank - 1, rank):
+ expected_after = np.expand_dims(before, axis)
+
+ for axis_as_tensor in [False, True]:
+ dense_shape_t = constant_op.constant(dense_shape, dtype=dtypes.int64)
+ indices_t = constant_op.constant(indices)
+ values_t = constant_op.constant(values)
+ before_t = sparse_tensor.SparseTensor(
+ indices=indices_t, values=values_t, dense_shape=dense_shape_t)
+
+ if axis_as_tensor:
+ axis = constant_op.constant(axis)
+
+ s = sparse_ops.sparse_expand_dims(before_t, axis)
+ d = sparse_ops.sparse_to_dense(s.indices, s.dense_shape, s.values)
+ self.assertAllEqual(self.evaluate(d), expected_after)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 571265665b..f7da3f7d64 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -2336,10 +2336,15 @@ class PartitionedVariable(object):
def as_tensor(self):
"""Returns the overall concatenated value as a `Tensor`.
+ The returned tensor will not inherit the control dependencies from the scope
+ where the value is used, matching the behavior of reading the value of a
+ `Variable`.
+
Returns:
`Tensor` containing the concatenated value.
"""
- return self._concat()
+ with ops.control_dependencies(None):
+ return self._concat()
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
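A minimal graph-mode sketch of the design choice: `control_dependencies(None)` clears the dependencies of the enclosing scope, so the concat that reads the partitioned value does not pick up control inputs from wherever it happens to be used:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  barrier = tf.no_op(name="barrier")
  with tf.control_dependencies([barrier]):
    with tf.control_dependencies(None):  # Clears the outer dependency.
      value = tf.constant(1.0)
  assert not value.op.control_inputs     # No control edge from `barrier`.
```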
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 85f2904318..b7aa8264b0 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -510,7 +510,10 @@ class CheckpointManager(object):
max_to_keep: An integer, the number of checkpoints to keep. Unless
preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
deleted from the active set, oldest first, until only `max_to_keep`
- checkpoints remain.
+ checkpoints remain. If `None`, no checkpoints are deleted and everything
+ stays in the active set. Note that `max_to_keep=None` will keep all
+ checkpoint paths in memory and in the checkpoint state protocol buffer
+ on disk.
keep_checkpoint_every_n_hours: Upon removal from the active set, a
checkpoint will be preserved if it has been at least
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
@@ -521,9 +524,10 @@ class CheckpointManager(object):
"""
self._checkpoint = checkpoint
self._save_counter_assign = None
- if not max_to_keep or max_to_keep < 0:
+ if max_to_keep is not None and max_to_keep <= 0:
raise ValueError(
- "Expected a positive integer for `max_to_max_to_keep`, got %d."
+ ("Expected a positive integer or `None` for `max_to_max_to_keep`, "
+ "got %d.")
% (max_to_keep,))
self._max_to_keep = max_to_keep
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
@@ -586,6 +590,10 @@ class CheckpointManager(object):
def _sweep(self):
"""Deletes or preserves managed checkpoints."""
+ if not self._max_to_keep:
+ # Does not update self._last_preserved_timestamp, since everything is kept
+ # in the active set.
+ return
while len(self._maybe_delete) > self._max_to_keep:
filename, timestamp = self._maybe_delete.popitem(last=False)
# Even if we're keeping this checkpoint due to
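A sketch of the `max_to_keep=None` behavior, using the internal module paths the tests below rely on; the directory is illustrative:

```python
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util

checkpoint = util.Checkpoint()
manager = checkpoint_management.CheckpointManager(
    checkpoint, "/tmp/all_ckpts", max_to_keep=None)

paths = [manager.save() for _ in range(3)]
# With max_to_keep=None, _sweep() returns immediately: every path stays on
# disk and in manager.checkpoints, and survives manager restarts.
assert manager.checkpoints == paths
```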
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 22c2cc678a..d7162265e6 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
@@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
@test_util.run_in_graph_and_eager_modes
+ def testKeepAll(self):
+ checkpoint = util.Checkpoint()
+ directory = os.path.join(
+ self.get_temp_dir(),
+ # Avoid sharing directories between eager and graph
+ # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
+ str(context.executing_eagerly()))
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertEqual(third_path, manager.latest_checkpoint)
+ self.assertEqual([first_path, second_path, third_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ fourth_path = manager.save()
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=3)
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ fifth_path = manager.save()
+ self.assertEqual([third_path, fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+
+ @test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testSaveRestoreState(self, mock_time):
directory = self.get_temp_dir()
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index ecadc56871..697b44c3ff 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -384,8 +384,8 @@ class CheckpointingTests(test.TestCase):
saver = saver_lib.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
with self.test_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 4b91d1e963..177a7ddfa5 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -363,10 +363,12 @@ class ExponentialMovingAverage(object):
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
- Returns an op that updates all shadow variables as described above.
+ Returns an op that updates all shadow variables from the current value of
+ their associated variables.
- Note that `apply()` can be called multiple times with different lists of
- variables.
+ Note that `apply()` can be called multiple times. When eager execution is
+ enabled, each call to apply updates the shadow variables once, so it needs
+ to be called in a loop.
Args:
var_list: A list of Variable or Tensor objects. The variables
@@ -389,31 +391,30 @@ class ExponentialMovingAverage(object):
dtypes.float64]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
- if var in self._averages:
- raise ValueError("Moving average already computed for: %s" % var.name)
- # For variables: to lower communication bandwidth across devices we keep
- # the moving averages on the same device as the variables. For other
- # tensors, we rely on the existing device allocation mechanism.
- with ops.init_scope():
- if isinstance(var, variables.Variable):
- avg = slot_creator.create_slot(var,
- var.initialized_value(),
- self.name,
- colocate_with_primary=True)
- # NOTE(mrry): We only add `tf.Variable` objects to the
- # `MOVING_AVERAGE_VARIABLES` collection.
- ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
- else:
- avg = slot_creator.create_zeros_slot(
- var,
- self.name,
- colocate_with_primary=(var.op.type in ["Variable",
- "VariableV2",
- "VarHandleOp"]))
- if self._zero_debias:
- zero_debias_true.add(avg)
- self._averages[var] = avg
+ if var not in self._averages:
+ # For variables: to lower communication bandwidth across devices we keep
+ # the moving averages on the same device as the variables. For other
+ # tensors, we rely on the existing device allocation mechanism.
+ with ops.init_scope():
+ if isinstance(var, variables.Variable):
+ avg = slot_creator.create_slot(var,
+ var.initialized_value(),
+ self.name,
+ colocate_with_primary=True)
+ # NOTE(mrry): We only add `tf.Variable` objects to the
+ # `MOVING_AVERAGE_VARIABLES` collection.
+ ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
+ else:
+ avg = slot_creator.create_zeros_slot(
+ var,
+ self.name,
+ colocate_with_primary=(var.op.type in ["Variable",
+ "VariableV2",
+ "VarHandleOp"]))
+ if self._zero_debias:
+ zero_debias_true.add(avg)
+ self._averages[var] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
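A sketch of the eager-mode usage the updated docstring describes, assuming the TF 1.x `enable_eager_execution` entry point; the change above lets repeated `apply()` calls reuse existing shadow variables instead of raising:

```python
import tensorflow as tf
tf.enable_eager_execution()

v = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.25)
ema.apply([v])            # First call creates the shadow variable.

for _ in range(3):
  v.assign_add(1.0)
  ema.apply([v])          # Each call performs exactly one average update.

print(ema.average(v).numpy())
```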
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3e85e6bfa7..fdb8d795c3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variable_scope
@@ -254,6 +256,25 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertEqual(1, sess.run(v0))
self.assertEqual([17.5], sess.run(v1_avg))
+ @test_util.run_in_graph_and_eager_modes
+ def testBasicEager(self):
+ v0 = variables.Variable(1.0)
+ v1 = variables.Variable(2.0)
+
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ op = ema.apply([v0, v1])
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(op)
+
+ self.evaluate(v0.assign(2.0))
+ self.evaluate(v1.assign(4.0))
+
+ self.evaluate(ema.apply([v0, v1]))
+
+ self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75)
+ self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
+
def averageVariablesNamesHelper(self, zero_debias):
with self.test_session():
v0 = variables.Variable(10.0, name="v0")
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b46095d458..f5b2a22327 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2853,8 +2853,8 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
with self.test_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
saver.restore(sess, save_path)
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index 2be4dbb283..a5ac430ce7 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name
has no effect on exporting a constant.
api_name: Name of the API you want to generate (e.g. `tensorflow` or
`estimator`). Default is `tensorflow`.
+ allow_multiple_exports: Allow symbol to be exported multiple times under
+ different names.
"""
self._names = args
self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
+ self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
if api_names_attr in func.__dict__:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
+ if not self._allow_multiple_exports:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
setattr(func, api_names_attr, names)
def export_constant(self, module_name, name):
@@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
-estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
+estimator_export = functools.partial(
+ api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True)
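A sketch of the new flag; `my_fn` and the export names are hypothetical. `estimator_export` now sets it by default, since estimator symbols may legitimately be exported under several names:

```python
from tensorflow.python.util import tf_export

@tf_export.tf_export('example.name_one', allow_multiple_exports=True)
def my_fn():
  pass

# Without allow_multiple_exports=True, this second export would raise
# SymbolAlreadyExposedError because my_fn already carries API names.
tf_export.tf_export('example.name_two', allow_multiple_exports=True)(my_fn)
```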
diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h
index 3ef8deb72e..d78bbfd425 100644
--- a/tensorflow/stream_executor/lib/env.h
+++ b/tensorflow/stream_executor/lib/env.h
@@ -32,7 +32,7 @@ inline Status FileExists(const string& filename) {
}
inline Status FileExists(const port::StringPiece& filename) {
- return Env::Default()->FileExists(std::string(filename));
+ return Env::Default()->FileExists(string(filename));
}
} // namespace port
diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc
index 58a862206c..3d3da103e1 100644
--- a/tensorflow/stream_executor/lib/path.cc
+++ b/tensorflow/stream_executor/lib/path.cc
@@ -33,7 +33,7 @@ string JoinPathImpl(std::initializer_list<port::StringPiece> paths) {
if (path.empty()) continue;
if (result.empty()) {
- result = std::string(path);
+ result = string(path);
continue;
}
diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h
index b02fe4f56f..e77dfcef76 100644
--- a/tensorflow/stream_executor/lib/str_util.h
+++ b/tensorflow/stream_executor/lib/str_util.h
@@ -31,7 +31,7 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix)
if (tensorflow::str_util::EndsWith(str, suffix)) {
str.remove_suffix(suffix.size());
}
- return std::string(str);
+ return string(str);
}
using tensorflow::str_util::Lowercase;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
index bf1f94b6ae..269e18a0a7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
@@ -96,7 +96,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
index 8ba0e7480b..7ad4a32d43 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
@@ -9,6 +9,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
+ name: "clone_model"
+ argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
index 3f54bc33e7..ba9e651b34 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
member_method {
+ name: "expand_dims"
+ argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
index bf1f94b6ae..269e18a0a7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
@@ -96,7 +96,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
index 8ba0e7480b..7ad4a32d43 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
@@ -9,6 +9,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
+ name: "clone_model"
+ argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
index 3f54bc33e7..ba9e651b34 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
member_method {
+ name: "expand_dims"
+ argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 0482cf619a..27b350e13e 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -27,7 +27,7 @@ function run_configure_for_gpu_build {
}
function set_remote_cache_options {
- echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}"
+ echo "build --remote_instance_name=projects/tensorflow-testing/instances/default_instance" >> "${TMP_BAZELRC}"
echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}"
echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}"
echo "build --tls_enabled=true" >> "${TMP_BAZELRC}"
diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md
index a286e8a212..263f25bc48 100644
--- a/tensorflow/tools/docker/README.md
+++ b/tensorflow/tools/docker/README.md
@@ -1,3 +1,10 @@
+# WARNING: THESE IMAGES ARE DEPRECATED.
+
+TensorFlow's Dockerfiles are now located in
+[`tensorflow/tools/dockerfiles/`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles).
+
+This directory will eventually be removed.
+
# Using TensorFlow via Docker
This directory contains `Dockerfile`s to make it easy to get up and running with
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
new file mode 100644
index 0000000000..c484c162cb
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -0,0 +1,67 @@
+# TensorFlow Dockerfiles
+
+This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES
+MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from
+the files in `partials/` and the rules in `spec.yml`. See [the Contributing
+section](#contributing) for more information.
+
+## Building
+
+The Dockerfiles in the `dockerfiles` directory must have their build context set
+to **the directory with this README.md** to copy in helper files. For example:
+
+```bash
+$ docker build -f ./dockerfiles/cpu.Dockerfile -t tf .
+```
+
+Each Dockerfile has its own set of available `--build-arg`s which are documented
+in the Dockerfile itself.
+
+## Running
+
+After building the image with the tag `tf` (for example), use `docker run` to
+run the images. Examples are below.
+
+Note for new Docker users: the `-v` and `-u` flags share directories between
+the Docker container and your machine, and are very important. Without
+`-v`, your work will be wiped once the container quits, and without `-u`, files
+created by the container will have the wrong file permissions on your host
+machine. If you are confused, check out the [Docker run
+documentation](https://docs.docker.com/engine/reference/run/).
+
+```bash
+# Volume mount (-v) is optional but highly recommended, especially for Jupyter.
+# User permissions (-u) are required if you use (-v).
+
+# CPU-based images
+$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
+
+# GPU-based images (set up nvidia-docker2 first)
+$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
+
+# Images with Jupyter run on port 8888 and need a volume for notebooks
+$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf
+```
+
+These images do not come with the TensorFlow source code -- but the development
+images have git included, so you can `git clone` it yourself.
+
+## Contributing
+
+To make changes to TensorFlow's Dockerfiles, you'll update `spec.yml` and the
+`*.partial.Dockerfile` files in the `partials` directory, then run
+`assembler.py` to re-generate the full Dockerfiles before creating a pull
+request.
+
+You can use the `assembler.Dockerfile` in this directory to build an editing
+environment that has all of the Python dependencies you'll need:
+
+```bash
+$ docker build -t tf-assembler -f assembler.Dockerfile .
+
+# Set --user to set correct permissions on generated files
+$ docker run --user $(id -u):$(id -g) -it -v $(pwd):/tf tf-assembler bash
+
+# In the container...
+/tf $ python3 ./assembler.py -o dockerfiles -s spec.yml
+```
diff --git a/tensorflow/tools/dockerfiles/assembler.Dockerfile b/tensorflow/tools/dockerfiles/assembler.Dockerfile
new file mode 100644
index 0000000000..7a8e07fced
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/assembler.Dockerfile
@@ -0,0 +1,30 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# TensorFlow Dockerfile Development Container
+#
+# You can use this image to quickly develop changes to the Dockerfile assembler
+# or set of TF Docker partials. See README.md for usage instructions.
+FROM debian:stretch
+LABEL maintainer="Austin Anderson <angerson@google.com>"
+
+RUN apt-get update && apt-get install -y python3 python3-pip bash
+RUN pip3 install --upgrade pip setuptools pyyaml absl-py cerberus
+
+WORKDIR /tf
+VOLUME ["/tf"]
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
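The assembler added below validates `spec.yml` with Cerberus before generating anything. A hedged, minimal sketch of that flow with a stand-in schema (not the real `SCHEMA_TEXT`):

```python
import cerberus
import yaml

schema = yaml.safe_load("""
header:
  type: string
partials:
  type: dict
""")

spec = yaml.safe_load("""
header: An example header comment.
partials:
  ubuntu:
    desc: base image partial
""")

validator = cerberus.Validator(schema)
if not validator.validate(spec):
  raise ValueError('Invalid spec: {}'.format(validator.errors))
```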
diff --git a/tensorflow/tools/dockerfiles/assembler.py b/tensorflow/tools/dockerfiles/assembler.py
new file mode 100644
index 0000000000..9cdd9bb0cb
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/assembler.py
@@ -0,0 +1,554 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Assemble common TF Dockerfiles from many parts.
+
+This script constructs TF's Dockerfiles by aggregating partial
+Dockerfiles. See README.md for usage examples.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import errno
+import os
+import os.path
+import re
+import shutil
+import textwrap
+
+from absl import app
+from absl import flags
+import cerberus
+import yaml
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_boolean(
+ 'dry_run', False, 'Do not actually generate Dockerfiles', short_name='n')
+
+flags.DEFINE_string(
+ 'spec_file',
+ './spec.yml',
+ 'Path to a YAML specification file',
+ short_name='s')
+
+flags.DEFINE_string(
+ 'output_dir',
+ './dockerfiles', ('Path to an output directory for Dockerfiles. '
+ 'Will be created if it doesn\'t exist.'),
+ short_name='o')
+
+flags.DEFINE_string(
+ 'partial_dir',
+ './partials',
+ 'Path to a directory containing foo.partial.Dockerfile partial files.',
+ short_name='p')
+
+flags.DEFINE_boolean(
+ 'quiet_dry_run',
+ True,
+ 'Do not print contents of dry run Dockerfiles.',
+ short_name='q')
+
+flags.DEFINE_boolean(
+ 'validate', True, 'Validate generated Dockerfiles', short_name='c')
+
+# Schema to verify the contents of spec.yml with Cerberus.
+# Must be converted to a dict from yaml to work.
+# Note: can add python references with e.g.
+# !!python/name:builtins.str
+# !!python/name:__main__.funcname
+SCHEMA_TEXT = """
+header:
+ type: string
+
+partials:
+ type: dict
+ keyschema:
+ type: string
+ valueschema:
+ type: dict
+ schema:
+ desc:
+ type: string
+ args:
+ type: dict
+ keyschema:
+ type: string
+ valueschema:
+ anyof:
+ - type: [ boolean, number, string ]
+ - type: dict
+ schema:
+ default:
+ type: [ boolean, number, string ]
+ desc:
+ type: string
+ options:
+ type: list
+ schema:
+ type: string
+
+images:
+ keyschema:
+ type: string
+ valueschema:
+ type: dict
+ schema:
+ desc:
+ type: string
+ arg-defaults:
+ type: list
+ schema:
+ anyof:
+ - type: dict
+ keyschema:
+ type: string
+ arg_in_use: true
+ valueschema:
+ type: string
+ - type: string
+ isimage: true
+ create-dockerfile:
+ type: boolean
+ partials:
+ type: list
+ schema:
+ anyof:
+ - type: dict
+ keyschema:
+ type: string
+ regex: image
+ valueschema:
+ type: string
+ isimage: true
+ - type: string
+ ispartial: true
+"""
+
+
+class TfDockerValidator(cerberus.Validator):
+ """Custom Cerberus validator for TF dockerfile spec.
+
+ Note: Each _validate_foo function's docstring must end with a segment
+ describing its own validation schema, e.g. "The rule's arguments are...". If
+ you add a new validator, you can copy/paste that section.
+ """
+
+ def _validate_ispartial(self, ispartial, field, value):
+ """Validate that a partial references an existing partial spec.
+
+ Args:
+ ispartial: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if ispartial and value not in self.root_document.get('partials', dict()):
+ self._error(field, '{} is not an existing partial.'.format(value))
+
+ def _validate_isimage(self, isimage, field, value):
+ """Validate that an image references an existing partial spec.
+
+ Args:
+ isimage: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if isimage and value not in self.root_document.get('images', dict()):
+ self._error(field, '{} is not an existing image.'.format(value))
+
+ def _validate_arg_in_use(self, arg_in_use, field, value):
+ """Validate that an arg references an existing partial spec's args.
+
+ Args:
+ arg_in_use: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if arg_in_use:
+ for partial in self.root_document.get('partials', dict()).values():
+ if value in partial.get('args', tuple()):
+ return
+
+ self._error(field, '{} is not an arg used in any partial.'.format(value))
+
+
+def build_partial_description(partial_spec):
+ """Create the documentation lines for a specific partial.
+
+ Generates something like this:
+
+ # This is the partial's description, from spec.yml.
+ # --build-arg ARG_NAME=argdefault
+ # this is one of the args.
+ # --build-arg ANOTHER_ARG=(some|choices)
+ # another arg.
+
+ Args:
+ partial_spec: A dict representing one of the partials from spec.yml. Doesn't
+ include the name of the partial; is a dict like { desc: ..., args: ... }.
+
+ Returns:
+ A commented string describing this partial.
+ """
+
+ # Start from linewrapped desc field
+ lines = []
+ wrapper = textwrap.TextWrapper(
+ initial_indent='# ', subsequent_indent='# ', width=80)
+ description = wrapper.fill(partial_spec.get('desc', '( no comments )'))
+ lines.extend(['#', description])
+
+ # Document each arg
+ for arg, arg_data in partial_spec.get('args', dict()).items():
+ # Wrap arg description with comment lines
+ desc = arg_data.get('desc', '( no description )')
+ desc = textwrap.fill(
+ desc,
+ initial_indent='# ',
+ subsequent_indent='# ',
+ width=80,
+ drop_whitespace=False)
+
+ # Document (each|option|like|this)
+ if 'options' in arg_data:
+ arg_options = ' ({})'.format('|'.join(arg_data['options']))
+ else:
+ arg_options = ''
+
+ # Add usage sample
+ arg_use = '# --build-arg {}={}{}'.format(arg,
+ arg_data.get('default', '(unset)'),
+ arg_options)
+ lines.extend([arg_use, desc])
+
+ return '\n'.join(lines)
+
+
+def construct_contents(partial_specs, image_spec):
+ """Assemble the dockerfile contents for an image spec.
+
+ It assembles a concrete list of partial references into a single, large
+ string.
+ Also expands argument defaults, so that the resulting Dockerfile doesn't have
+ to be configured with --build-arg=... every time. That is, any ARG directive
+ will be updated with a new default value.
+
+ Args:
+ partial_specs: The dict from spec.yml["partials"].
+ image_spec: One of the dict values from spec.yml["images"].
+
+ Returns:
+ A string containing a valid Dockerfile based on the partials listed in
+ image_spec.
+ """
+ processed_partial_strings = []
+ for partial_name in image_spec['partials']:
+ # Apply image arg-defaults to existing arg defaults
+ partial_spec = copy.deepcopy(partial_specs[partial_name])
+ args = partial_spec.get('args', dict())
+ for k_v in image_spec.get('arg-defaults', []):
+ arg, value = list(k_v.items())[0]
+ if arg in args:
+ args[arg]['default'] = value
+
+ # Read partial file contents
+ filename = partial_spec.get('file', partial_name)
+ partial_path = os.path.join(FLAGS.partial_dir,
+ '{}.partial.Dockerfile'.format(filename))
+ with open(partial_path, 'r') as f_partial:
+ partial_contents = f_partial.read()
+
+ # Replace ARG FOO=BAR with ARG FOO=[new-default]
+ for arg, arg_data in args.items():
+ if 'default' in arg_data and arg_data['default']:
+ default = '={}'.format(arg_data['default'])
+ else:
+ default = ''
+ partial_contents = re.sub(r'ARG {}.*'.format(arg), 'ARG {}{}'.format(
+ arg, default), partial_contents)
+
+ # Store updated partial contents
+ processed_partial_strings.append(partial_contents)
+
+ # Join everything together
+ return '\n'.join(processed_partial_strings)
+
+
+def mkdir_p(path):
+ """Create a directory and its parents, even if it already exists."""
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def construct_documentation(header, partial_specs, image_spec):
+ """Assemble all of the documentation for a single dockerfile.
+
+ Builds explanations of included partials and available build args.
+
+ Args:
+ header: The string from spec.yml["header"]; will be commented and wrapped.
+ partial_specs: The dict from spec.yml["partials"].
+ image_spec: The spec for the dockerfile being built.
+
+ Returns:
+ A string containing a commented header that documents the contents of the
+ dockerfile.
+
+ """
+ # Comment and wrap header and image description
+ commented_header = '\n'.join(
+ [('# ' + l).rstrip() for l in header.splitlines()])
+ commented_desc = '\n'.join(
+ ['# ' + l for l in image_spec.get('desc', '').splitlines()])
+ partial_descriptions = []
+
+ # Build documentation for each partial in the image
+ for partial in image_spec['partials']:
+ # Copy partial data for default args unique to this image
+ partial_spec = copy.deepcopy(partial_specs[partial])
+ args = partial_spec.get('args', dict())
+
+ # Overwrite any existing arg defaults
+ for k_v in image_spec.get('arg-defaults', []):
+ arg, value = list(k_v.items())[0]
+ if arg in args:
+ args[arg]['default'] = value
+
+ # Build the description from new args
+ partial_description = build_partial_description(partial_spec)
+ partial_descriptions.append(partial_description)
+
+ contents = [commented_header, '#', commented_desc] + partial_descriptions
+ return '\n'.join(contents) + '\n'
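+
+# Illustration: the leading comment block of each generated Dockerfile in this
+# change (e.g. cpu.Dockerfile) is the output of this function.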
+
+
+def normalize_partial_args(partial_specs):
+ """Normalize the shorthand form of a partial's args specification.
+
+ Turns this:
+
+ partial:
+ args:
+ SOME_ARG: arg_value
+
+ Into this:
+
+ partial:
+ args:
+ SOME_ARG:
+ default: arg_value
+
+ Args:
+ partial_specs: The dict from spec.yml["partials"]. This dict is modified in
+ place.
+
+ Returns:
+ The modified contents of partial_specs.
+
+ """
+ for _, partial in partial_specs.items():
+ args = partial.get('args', dict())
+ for arg, value in args.items():
+ if not isinstance(value, dict):
+ new_value = {'default': value}
+ args[arg] = new_value
+
+ return partial_specs
+
+
+def flatten_args_references(image_specs):
+ """Resolve all default-args in each image spec to a concrete dict.
+
+ Turns this:
+
+ example-image:
+ arg-defaults:
+ - MY_ARG: ARG_VALUE
+
+ another-example:
+ arg-defaults:
+ - ANOTHER_ARG: ANOTHER_VALUE
+ - example_image
+
+ Into this:
+
+ example-image:
+ arg-defaults:
+ - MY_ARG: ARG_VALUE
+
+ another-example:
+ arg-defaults:
+ - ANOTHER_ARG: ANOTHER_VALUE
+ - MY_ARG: ARG_VALUE
+
+ Args:
+ image_specs: A dict of image_spec dicts; should be the contents of the
+ "images" key in the global spec.yaml. This dict is modified in place and
+ then returned.
+
+ Returns:
+ The modified contents of image_specs.
+ """
+ for _, image_spec in image_specs.items():
+ too_deep = 0
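+ # An image's arg-defaults may name other images whose arg-defaults contain
+ # further references; keep expanding, but cap at 5 passes so a circular
+ # reference cannot loop forever.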
+ while str in map(type, image_spec.get('arg-defaults', [])) and too_deep < 5:
+ new_args = []
+ for arg in image_spec['arg-defaults']:
+ if isinstance(arg, str):
+ new_args.extend(image_specs[arg]['arg-defaults'])
+ else:
+ new_args.append(arg)
+
+ image_spec['arg-defaults'] = new_args
+ too_deep += 1
+
+ return image_specs
+
+
+def flatten_partial_references(image_specs):
+ """Resolve all partial references in each image spec to a concrete list.
+
+ Turns this:
+
+ example-image:
+ partials:
+ - foo
+
+ another-example:
+ partials:
+ - bar
+ - image: example-image
+ - bat
+
+ Into this:
+
+ example-image:
+ partials:
+ - foo
+
+ another-example:
+ partials:
+ - bar
+ - foo
+ - bat
+
+ Args:
+ image_specs: A dict of image_spec dicts; should be the contents of the
+ "images" key in the global spec.yaml. This dict is modified in place and
+ then returned.
+
+ Returns:
+ The modified contents of image_specs.
+ """
+ for _, image_spec in image_specs.items():
+ too_deep = 0
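+ # As above, cap the expansion of nested image references at 5 passes to
+ # guard against circular references.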
+ while dict in map(type, image_spec['partials']) and too_deep < 5:
+ new_partials = []
+ for partial in image_spec['partials']:
+ if isinstance(partial, str):
+ new_partials.append(partial)
+ else:
+ new_partials.extend(image_specs[partial['image']]['partials'])
+
+ image_spec['partials'] = new_partials
+ too_deep += 1
+
+ return image_specs
+
+
+def construct_dockerfiles(tf_spec):
+ """Generate a mapping of {"cpu": <cpu dockerfile contents>, ...}.
+
+ Args:
+ tf_spec: The full spec.yml loaded as a python object.
+
+ Returns:
+ A string:string dict of short names ("cpu-devel") to Dockerfile contents.
+ """
+ names_to_contents = dict()
+ image_specs = tf_spec['images']
+ image_specs = flatten_partial_references(image_specs)
+ image_specs = flatten_args_references(image_specs)
+ partial_specs = tf_spec['partials']
+ partial_specs = normalize_partial_args(partial_specs)
+
+ for name, image_spec in image_specs.items():
+ if not image_spec.get('create-dockerfile', True):
+ continue
+ documentation = construct_documentation(tf_spec['header'], partial_specs,
+ image_spec)
+ contents = construct_contents(partial_specs, image_spec)
+ names_to_contents[name] = '\n'.join([documentation, contents])
+
+ return names_to_contents
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Unexpected command line args found: {}'.format(argv))
+
+ with open(FLAGS.spec_file, 'r') as spec_file:
+ tf_spec = yaml.load(spec_file)
+
+ # Abort if the spec file is invalid
+ if FLAGS.validate:
+ schema = yaml.load(SCHEMA_TEXT)
+ v = TfDockerValidator(schema)
+ if not v.validate(tf_spec):
+ print('>> ERROR: {} is an invalid spec! The errors are:'.format(
+ FLAGS.spec_file))
+ print(yaml.dump(v.errors, indent=2))
+ exit(1)
+ else:
+ print('>> WARNING: Not validating {}'.format(FLAGS.spec_file))
+
+ # Generate mapping of { "cpu-devel": "<cpu-devel dockerfile contents>", ... }
+ names_to_contents = construct_dockerfiles(tf_spec)
+
+ # Write each completed Dockerfile
+ if not FLAGS.dry_run:
+ print('>> Emptying destination dir "{}"'.format(FLAGS.output_dir))
+ shutil.rmtree(FLAGS.output_dir, ignore_errors=True)
+ mkdir_p(FLAGS.output_dir)
+ else:
+ print('>> Skipping creation of {} (dry run)'.format(FLAGS.output_dir))
+ for name, contents in names_to_contents.items():
+ path = os.path.join(FLAGS.output_dir, name + '.Dockerfile')
+ if FLAGS.dry_run:
+ print('>> Skipping writing contents of {} (dry run)'.format(path))
+ print(contents)
+ else:
+ mkdir_p(FLAGS.output_dir)
+ print('>> Writing {}'.format(path))
+ with open(path, 'w') as f:
+ f.write(contents)
+
+
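+# Example invocation (the script name and paths here are illustrative; the
+# flags are the ones defined by this script):
+#   python assembler.py --spec_file=spec.yml --partial_dir=partials \
+#       --output_dir=dockerfiles --validate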
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tensorflow/tools/dockerfiles/bashrc b/tensorflow/tools/dockerfiles/bashrc
new file mode 100644
index 0000000000..48cacf20f6
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/bashrc
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
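+# Prompt: a red "tf-docker" tag and a yellow working directory
+# (ANSI escapes: \e[31m = red, \e[33m = yellow, \e[m = reset).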
+export PS1="\[\e[31m\]tf-docker\[\e[m\] \[\e[33m\]\w\[\e[m\] > "
+export TERM=xterm-256color
+alias grep="grep --color=auto"
+alias ls="ls --color=auto"
+
+echo -e "\e[1;31m"
+cat<<TF
+________ _______________
+___ __/__________________________________ ____/__ /________ __
+__ / _ _ \_ __ \_ ___/ __ \_ ___/_ /_ __ /_ __ \_ | /| / /
+_ / / __/ / / /(__ )/ /_/ / / _ __/ _ / / /_/ /_ |/ |/ /
+/_/ \___//_/ /_//____/ \____//_/ /_/ /_/ \____/____/|__/
+
+TF
+echo -e "\e[0;33m"
+
+if [[ $EUID -eq 0 ]]; then
+ cat <<WARN
+WARNING: You are running this container as root, which can cause new files in
+mounted volumes to be created as the root user on your host machine.
+
+To avoid this, run the container with your user ID and group ID:
+
+$ docker run -u \$(id -u):\$(id -g) args...
+WARN
+else
+ cat <<EXPL
+You are running this container as a user with ID $(id -u) and group $(id -g),
+which should map to the ID and group for your user on the Docker host. Great!
+EXPL
+fi
+
+# Turn off colors
+echo -e "\e[m"
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile
new file mode 100644
index 0000000000..dbbad7d03a
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile
@@ -0,0 +1,100 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, CPU-only environment for developing changes for TensorFlow, with Jupyter included.
+#
+# Start from Ubuntu, with TF development packages (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile
new file mode 100644
index 0000000000..160d7c02e2
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, CPU-only environment for developing changes for TensorFlow.
+#
+# Start from Ubuntu, with TF development packages (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
new file mode 100644
index 0000000000..8d5d653ab7
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter included.
+#
+# Start from Ubuntu (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
new file mode 100644
index 0000000000..35c41b49fd
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, CPU-only environment for using TensorFlow
+#
+# Start from Ubuntu (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
new file mode 100644
index 0000000000..0f5fedf2fe
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
@@ -0,0 +1,120 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow, with Jupyter included.
+#
+# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, and TF development
+# packages.
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+# Link the NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
new file mode 100644
index 0000000000..a6e280082e
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow.
+#
+# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, and TF development
+# packages.
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+# Link the NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
new file mode 100644
index 0000000000..f1799113b1
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
@@ -0,0 +1,90 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with Jupyter included.
+#
+# NVIDIA with CUDA and CuDNN, no development packages
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow-gpu
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
new file mode 100644
index 0000000000..690eb68b22
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
@@ -0,0 +1,79 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented with their default values.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow.
+#
+# NVIDIA with CUDA and CuDNN, no development packages
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install Python 3 instead of Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow-gpu
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile
new file mode 100644
index 0000000000..b08d8bdd14
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile
@@ -0,0 +1,13 @@
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
new file mode 100644
index 0000000000..2c9b9f3f9a
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
@@ -0,0 +1,8 @@
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
new file mode 100644
index 0000000000..f31b695e77
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
@@ -0,0 +1,43 @@
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+# Link the NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
new file mode 100644
index 0000000000..13d865b9d4
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
@@ -0,0 +1,23 @@
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile
new file mode 100644
index 0000000000..6f346236a5
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile
@@ -0,0 +1,12 @@
+ARG USE_PYTHON_3_NOT_2
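+# ${USE_PYTHON_3_NOT_2:+3} expands to "3" when the arg is set and non-empty,
+# and to "" otherwise, so PYTHON/PIP resolve to python3/pip3 or python/pip.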
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
diff --git a/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile
new file mode 100644
index 0000000000..d641a11b06
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile
@@ -0,0 +1,2 @@
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile
new file mode 100644
index 0000000000..96e79547f0
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile
@@ -0,0 +1,2 @@
+ARG TF_PACKAGE
+RUN ${PIP} install ${TF_PACKAGE}
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile
new file mode 100644
index 0000000000..bc79272276
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile
@@ -0,0 +1,24 @@
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile
new file mode 100644
index 0000000000..0a50735bf8
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile
@@ -0,0 +1,2 @@
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml
new file mode 100644
index 0000000000..28bf9a55da
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/spec.yml
@@ -0,0 +1,195 @@
+# ======
+# HEADER
+# ======
+#
+# This is commented-out and prepended to each generated Dockerfile.
+header: |
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ============================================================================
+
+ THIS IS A GENERATED DOCKERFILE.
+
+ This file was assembled from multiple pieces, whose use is documented
+ below. Please refer to the TensorFlow dockerfiles documentation for
+ more information. Build args are documented with their default values.
+
+# ========
+# PARTIALS
+# ========
+#
+# Represent and document pieces of a Dockerfile. Spec:
+#
+# name: the name of the partial; it is referenced from the images section
+# desc: A description, inserted later into the Dockerfile
+# file: Alternative file prefix, e.g. file.partial.Dockerfile. The default is
+# the name of the partial.
+# args: A dict of ARGs in the Dockerfile; each entry has the format
+# ARG_NAME: VALUE where VALUE is one of:
+# - a dict:
+# desc: Documentation for the arg
+# default: Default value for the arg; written to the Dockerfile
+# options: List of strings, part of documentation
+# - a concrete value: the same as a dictionary with default: [value].
+
+partials:
+ ubuntu:
+ desc: Start from Ubuntu (no GPU support)
+ args:
+ UBUNTU_VERSION: 16.04
+
+ ubuntu-devel:
+ desc: Start from Ubuntu, with TF development packages (no GPU support)
+ args:
+ UBUNTU_VERSION: 16.04
+
+ bazel:
+ desc: Install the latest version of Bazel and Python development tools.
+
+ nvidia:
+ desc: NVIDIA with CUDA and CuDNN, no development packages
+ args:
+ UBUNTU_VERSION: 16.04
+
+ nvidia-devel:
+ desc: >
+ Start from Nvidia's Ubuntu base image with CUDA and CuDNN, and TF
+ development packages.
+ args:
+ UBUNTU_VERSION: 16.04
+
+ python:
+ desc: Python is required for TensorFlow and other libraries.
+ args:
+ USE_PYTHON_3_NOT_2:
+ default: true
+ desc: Install Python 3 instead of Python 2
+
+ tensorflow:
+ desc: Install the TensorFlow Python package.
+ args:
+ TF_PACKAGE:
+ default: tensorflow
+ options:
+ - tensorflow
+ - tensorflow-gpu
+ - tf-nightly
+ - tf-nightly-gpu
+ desc: The specific TensorFlow Python package to install
+
+ shell:
+ desc: Configure TensorFlow's shell prompt and login tools.
+
+ jupyter:
+ desc: Launch Jupyter on execution instead of a bash prompt.
+
+# ======
+# IMAGES
+# ======
+#
+# Represent Dockerfiles. Spec:
+#
+# name: the name of the image, possibly referenced by other images
+# desc: A description, inserted later into the Dockerfile
+# create-dockerfile: Whether to generate a Dockerfile for this image. Set to
+# false for extensible base images that don't need their own file. Default is
+# true.
+# partials: List of VALUEs, where a VALUE is either:
+# - the name of a partial, which inserts that partial into this image
+# - image: [name of another image], which inserts the partials from that
+# image into this image
+# arg-defaults: List of VALUEs, where a VALUE is either:
+# - ARG_NAME: VALUE, which sets the ARG_NAME to VALUE wherever it appears
+# in this image's partials
+# - [name of another image], which loads the default args from that image
+images:
+
+ nodev:
+ create-dockerfile: false
+ partials:
+ - python
+ - tensorflow
+ - shell
+
+ dev:
+ create-dockerfile: false
+ partials:
+ - python
+ - bazel
+ - shell
+
+ cpu:
+ desc: Ubuntu-based, CPU-only environment for using TensorFlow
+ partials:
+ - ubuntu
+ - image: nodev
+
+ cpu-devel:
+ desc: >
+ Ubuntu-based, CPU-only environment for developing changes for
+ TensorFlow.
+ partials:
+ - ubuntu-devel
+ - image: dev
+
+ nvidia:
+ desc: Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow.
+ arg-defaults:
+ - TF_PACKAGE: tensorflow-gpu
+ partials:
+ - nvidia
+ - image: nodev
+
+ nvidia-devel:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for developing changes
+ for TensorFlow.
+ arg-defaults:
+ - TF_PACKAGE: tensorflow-gpu
+ partials:
+ - nvidia-devel
+ - image: dev
+
+ cpu-jupyter:
+ desc: >
+ Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter
+ included.
+ partials:
+ - image: cpu
+ - jupyter
+
+ cpu-devel-jupyter:
+ desc: >
+ Ubuntu-based, CPU-only environment for developing changes for
+ TensorFlow, with Jupyter included.
+ partials:
+ - image: cpu-devel
+ - jupyter
+
+ nvidia-jupyter:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with
+ Jupyter included.
+ arg-defaults:
+ - nvidia
+ partials:
+ - image: nvidia
+ - jupyter
+
+ nvidia-devel-jupyter:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for
+ TensorFlow, with Jupyter included.
+ arg-defaults:
+ - nvidia-devel
+ partials:
+ - image: nvidia-devel
+ - jupyter
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9d0ce34344..34b4a66c41 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -493,11 +493,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/6203c9bd082a877a20c218033636712135a3c2db.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d.tar.gz",
],
- sha256 = "83a80f9fb2a5949ca77e526344cbd4581388c3ec7fea5c59e488d46fd38e06d9",
- strip_prefix = "llvm-6203c9bd082a877a20c218033636712135a3c2db",
+ sha256 = "2889b79ab979e676e344974cfeefbaf2c21c7c69a015bd584e8ae67b87b136bc",
+ strip_prefix = "llvm-97d7bcd5c024ee6aec4eecbc723bb6d4f4c3dc3d",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)