aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--CODEOWNERS4
-rw-r--r--README.md2
-rw-r--r--configure.py6
-rw-r--r--tensorflow/BUILD5
-rw-r--r--tensorflow/c/BUILD9
-rw-r--r--tensorflow/c/eager/tape.h17
-rw-r--r--tensorflow/compiler/aot/BUILD3
-rw-r--r--tensorflow/compiler/aot/codegen.cc2
-rw-r--r--tensorflow/compiler/aot/embedded_protocol_buffers.cc2
-rw-r--r--tensorflow/compiler/aot/embedded_protocol_buffers.h4
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc50
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_internal.h2
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc145
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc37
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h5
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass.cc2
-rw-r--r--tensorflow/compiler/jit/xla_device.cc19
-rw-r--r--tensorflow/compiler/jit/xla_device.h7
-rw-r--r--tensorflow/compiler/tests/BUILD2
-rw-r--r--tensorflow/compiler/tests/adam_test.py7
-rw-r--r--tensorflow/compiler/tests/ftrl_test.py51
-rw-r--r--tensorflow/compiler/tests/qr_op_test.py4
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc19
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py2
-rw-r--r--tensorflow/compiler/tf2xla/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc24
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc30
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/qr_op.cc14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/select_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/slice_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD3
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc9
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.cc21
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.h2
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc8
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc34
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h24
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc8
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.h12
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc5
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h7
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc12
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h8
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h4
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h8
-rw-r--r--tensorflow/compiler/xla/BUILD14
-rw-r--r--tensorflow/compiler/xla/array.h30
-rw-r--r--tensorflow/compiler/xla/array4d.h2
-rw-r--r--tensorflow/compiler/xla/array4d_test.cc24
-rw-r--r--tensorflow/compiler/xla/array_test.cc2
-rw-r--r--tensorflow/compiler/xla/client/BUILD5
-rw-r--r--tensorflow/compiler/xla/client/client.cc8
-rw-r--r--tensorflow/compiler/xla/client/client.h8
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.cc2
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.h2
-rw-r--r--tensorflow/compiler/xla/client/executable_build_options.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/lib/conv_grad_size_util.h1
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc3
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h3
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.cc57
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric_test.cc10
-rw-r--r--tensorflow/compiler/xla/client/lib/pooling.cc56
-rw-r--r--tensorflow/compiler/xla/client/lib/pooling.h29
-rw-r--r--tensorflow/compiler/xla/client/lib/pooling_test.cc6
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc10
-rw-r--r--tensorflow/compiler/xla/client/local_client.h15
-rw-r--r--tensorflow/compiler/xla/client/padding.cc13
-rw-r--r--tensorflow/compiler/xla/client/padding.h15
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc439
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h589
-rw-r--r--tensorflow/compiler/xla/index_util.cc15
-rw-r--r--tensorflow/compiler/xla/index_util.h14
-rw-r--r--tensorflow/compiler/xla/index_util_test.cc8
-rw-r--r--tensorflow/compiler/xla/layout_util.cc12
-rw-r--r--tensorflow/compiler/xla/layout_util.h15
-rw-r--r--tensorflow/compiler/xla/layout_util_test.cc6
-rw-r--r--tensorflow/compiler/xla/literal.cc151
-rw-r--r--tensorflow/compiler/xla/literal.h173
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc105
-rw-r--r--tensorflow/compiler/xla/literal_test.cc29
-rw-r--r--tensorflow/compiler/xla/literal_util.cc14
-rw-r--r--tensorflow/compiler/xla/literal_util.h49
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc10
-rw-r--r--tensorflow/compiler/xla/python/BUILD2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc82
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h66
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i22
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.h2
-rw-r--r--tensorflow/compiler/xla/reference_util.cc90
-rw-r--r--tensorflow/compiler/xla/reference_util.h60
-rw-r--r--tensorflow/compiler/xla/service/BUILD59
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc205
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc465
-rw-r--r--tensorflow/compiler/xla/service/backend.cc10
-rw-r--r--tensorflow/compiler/xla/service/backend.h4
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.cc9
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc124
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h13
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc232
-rw-r--r--tensorflow/compiler/xla/service/buffer_value.h2
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.h2
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.h4
-rw-r--r--tensorflow/compiler/xla/service/compiler.h2
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD10
-rw-r--r--tensorflow/compiler/xla/service/cpu/buffer_info_util.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/buffer_info_util.h4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc38
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h40
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h7
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc82
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h66
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc31
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.h17
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/shape_partition_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc27
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.h4
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager.h5
-rw-r--r--tensorflow/compiler/xla/service/defuser_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.cc2
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.h4
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h2
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc37
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/executable.cc7
-rw-r--r--tensorflow/compiler/xla/service/executable.h57
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.cc5
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc5
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD51
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h17
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc47
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc55
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc59
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc54
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc26
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h32
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc33
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_fusible.cc84
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_fusible.h49
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc332
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule.cc)8
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule.h)16
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc)25
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc72
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc34
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h17
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc84
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h62
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc61
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h9
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_buffer.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h39
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc65
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc108
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h16
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h302
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc102
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h68
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc50
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc214
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h221
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc81
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc42
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h12
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/inliner.cc2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc5
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.h6
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executor.h4
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc16
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h13
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc33
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.h33
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc5
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc9
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h13
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/sort_util.cc39
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h5
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/local_service.h4
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer.h2
-rw-r--r--tensorflow/compiler/xla/service/maybe_owning_device_memory.cc41
-rw-r--r--tensorflow/compiler/xla/service/maybe_owning_device_memory.h70
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h1
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc28
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.cc3
-rw-r--r--tensorflow/compiler/xla/service/service.cc33
-rw-r--r--tensorflow/compiler/xla/service/service.h26
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc111
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h66
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.h10
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h7
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/tuple_util.cc4
-rw-r--r--tensorflow/compiler/xla/service/tuple_util.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_analysis.cc3
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc1
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/while_util.cc2
-rw-r--r--tensorflow/compiler/xla/service/while_util.h2
-rw-r--r--tensorflow/compiler/xla/shape_tree.h27
-rw-r--r--tensorflow/compiler/xla/shape_util.cc63
-rw-r--r--tensorflow/compiler/xla/shape_util.h74
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc14
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.cc19
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.h19
-rw-r--r--tensorflow/compiler/xla/sparse_index_array_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/BUILD43
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc55
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h116
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/deallocation_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/floor_ceil_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/half_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc21
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h17
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h4
-rw-r--r--tensorflow/compiler/xla/tests/iota_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h10
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc8
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h10
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc17
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc46
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc25
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc19
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h2
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc6
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.cc5
-rw-r--r--tensorflow/compiler/xla/tools/BUILD7
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc6
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc6
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_text.cc6
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc6
-rw-r--r--tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc7
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc6
-rw-r--r--tensorflow/compiler/xla/tools/show_signature.cc6
-rw-r--r--tensorflow/compiler/xla/util.cc19
-rw-r--r--tensorflow/compiler/xla/util.h103
-rw-r--r--tensorflow/compiler/xla/util_test.cc43
-rw-r--r--tensorflow/compiler/xla/window_util.cc5
-rw-r--r--tensorflow/compiler/xla/window_util.h6
-rw-r--r--tensorflow/compiler/xla/xla_data.proto14
-rw-r--r--tensorflow/compiler/xrt/BUILD83
-rw-r--r--tensorflow/compiler/xrt/cc/BUILD20
-rw-r--r--tensorflow/compiler/xrt/kernels/BUILD72
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc239
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_execute_op.cc254
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_state_ops.cc110
-rw-r--r--tensorflow/compiler/xrt/kernels/xrt_state_ops.h424
-rw-r--r--tensorflow/compiler/xrt/ops/xrt_compile_ops.cc48
-rw-r--r--tensorflow/compiler/xrt/ops/xrt_execute_op.cc44
-rw-r--r--tensorflow/compiler/xrt/ops/xrt_state_ops.cc122
-rw-r--r--tensorflow/compiler/xrt/tests/BUILD65
-rw-r--r--tensorflow/compiler/xrt/tests/raw_api_test.cc421
-rw-r--r--tensorflow/compiler/xrt/xrt.proto78
-rw-r--r--tensorflow/compiler/xrt/xrt_compilation_cache.cc263
-rw-r--r--tensorflow/compiler/xrt/xrt_compilation_cache.h238
-rw-r--r--tensorflow/compiler/xrt/xrt_device.cc46
-rw-r--r--tensorflow/compiler/xrt/xrt_device.h66
-rw-r--r--tensorflow/compiler/xrt/xrt_state.cc458
-rw-r--r--tensorflow/compiler/xrt/xrt_state.h208
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/BUILD1
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/BUILD4
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/__init__.py0
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc195
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py7
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py111
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc22
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto12
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py9
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py9
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py157
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py127
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py41
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py65
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py41
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py21
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py6
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py24
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py23
-rw-r--r--tensorflow/contrib/distribute/BUILD1
-rw-r--r--tensorflow/contrib/distribute/README.md301
-rw-r--r--tensorflow/contrib/distribute/__init__.py2
-rw-r--r--tensorflow/contrib/distribute/python/BUILD3
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py150
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py12
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py68
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py51
-rw-r--r--tensorflow/contrib/distribute/python/input_ops.py13
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py56
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py105
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py2
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py215
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py52
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py3
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py107
-rw-r--r--tensorflow/contrib/distribute/python/values.py6
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py5
-rw-r--r--tensorflow/contrib/distributions/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb7
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py6
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py33
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py12
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc60
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc44
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py10
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py7
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py127
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py10
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md40
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py51
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py14
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD1
-rw-r--r--tensorflow/contrib/lite/examples/android/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/examples/ios/camera/Podfile2
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/Podfile2
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md6
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md77
-rw-r--r--tensorflow/contrib/lite/g3doc/rpi.md34
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc113
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc360
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc309
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc139
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc63
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h331
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc29
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h7
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h414
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h426
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h909
-rw-r--r--tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h5
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc114
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h5
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h5
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc75
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc22
-rw-r--r--tensorflow/contrib/lite/kernels/pow.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc324
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/slice.cc26
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc108
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc73
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc31
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc20
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc7
-rw-r--r--tensorflow/contrib/lite/python/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/convert.py68
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py89
-rw-r--r--tensorflow/contrib/lite/python/lite.py185
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py306
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py50
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/args.h3
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc8
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md22
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md24
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc46
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/model.h1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc40
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h7
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc54
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc27
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto8
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc10
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/BUILD14
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/README.md6
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD10
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md8
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt1762
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc53
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc229
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h29
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc25
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h11
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/README.md22
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/README.md4
-rw-r--r--tensorflow/contrib/lite/tools/optimize/BUILD14
-rw-r--r--tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md70
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc292
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.h6
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc108
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt1
-rw-r--r--tensorflow/contrib/model_pruning/BUILD2
-rw-r--r--tensorflow/contrib/opt/BUILD16
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions.py155
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions_test.py63
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py98
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py194
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers.py72
-rw-r--r--tensorflow/contrib/saved_model/BUILD17
-rw-r--r--tensorflow/contrib/saved_model/__init__.py7
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py260
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py293
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py11
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py316
-rw-r--r--tensorflow/contrib/tpu/BUILD2
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc14
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py29
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py822
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py289
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py122
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py27
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py9
-rw-r--r--tensorflow/contrib/training/python/training/evaluation_test.py31
-rw-r--r--tensorflow/core/BUILD70
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt112
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt12
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt14
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc14
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc38
-rw-r--r--tensorflow/core/common_runtime/direct_session.h12
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc117
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD5
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc2
-rw-r--r--tensorflow/core/common_runtime/executor.cc7
-rw-r--r--tensorflow/core/common_runtime/function.cc17
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc17
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc41
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.h7
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.cc427
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.h (renamed from tensorflow/core/util/status_util_test.cc)35
-rw-r--r--tensorflow/core/common_runtime/lower_while_op_test.cc249
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h5
-rw-r--r--tensorflow/core/common_runtime/placer.cc10
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc5
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc70
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h5
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc37
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.h3
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc32
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc9
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc51
-rw-r--r--tensorflow/core/debug/debug_io_utils.h29
-rw-r--r--tensorflow/core/debug/debug_io_utils_test.cc46
-rw-r--r--tensorflow/core/debug/debugger_state_impl.cc3
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc15
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.h3
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc6
-rw-r--r--tensorflow/core/framework/dataset.cc7
-rw-r--r--tensorflow/core/framework/dataset.h29
-rw-r--r--tensorflow/core/framework/function_testlib.cc56
-rw-r--r--tensorflow/core/framework/function_testlib.h9
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h2
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD5
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc24
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc173
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h7
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc230
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc10
-rw-r--r--tensorflow/core/kernels/BUILD26
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc6
-rw-r--r--tensorflow/core/kernels/collective_ops.cc46
-rw-r--r--tensorflow/core/kernels/cwise_op_zeta.cc5
-rw-r--r--tensorflow/core/kernels/data/BUILD6
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc196
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc11
-rw-r--r--tensorflow/core/kernels/data/tensor_slice_dataset_op.cc11
-rw-r--r--tensorflow/core/kernels/debug_ops.h9
-rw-r--r--tensorflow/core/kernels/eigen_benchmark.h298
-rw-r--r--tensorflow/core/kernels/eigen_benchmark_cpu_test.cc402
-rw-r--r--tensorflow/core/kernels/example_parsing_ops.cc165
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h15
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc15
-rw-r--r--tensorflow/core/kernels/list_kernels.h121
-rw-r--r--tensorflow/core/kernels/logistic-loss.h2
-rw-r--r--tensorflow/core/kernels/loss_test.cc174
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc31
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc42
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc41
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc23
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc37
-rw-r--r--tensorflow/core/kernels/poisson-loss.h109
-rw-r--r--tensorflow/core/kernels/qr_op_complex128.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_complex64.cc6
-rw-r--r--tensorflow/core/kernels/qr_op_double.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_float.cc12
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc22
-rw-r--r--tensorflow/core/kernels/sdca_internal.cc2
-rw-r--r--tensorflow/core/kernels/sdca_ops.cc3
-rw-r--r--tensorflow/core/kernels/set_kernels.cc14
-rw-r--r--tensorflow/core/kernels/sparse_softmax_op.cc2
-rw-r--r--tensorflow/core/kernels/training_ops.cc52
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h4
-rw-r--r--tensorflow/core/lib/core/errors.h20
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc54
-rw-r--r--tensorflow/core/lib/core/stringpiece.h117
-rw-r--r--tensorflow/core/lib/gtl/array_slice.h281
-rw-r--r--tensorflow/core/lib/gtl/array_slice_internal.h269
-rw-r--r--tensorflow/core/lib/gtl/array_slice_test.cc664
-rw-r--r--tensorflow/core/lib/gtl/optional.cc25
-rw-r--r--tensorflow/core/lib/gtl/optional.h853
-rw-r--r--tensorflow/core/lib/gtl/optional_test.cc1098
-rw-r--r--tensorflow/core/lib/strings/strcat.h37
-rw-r--r--tensorflow/core/lib/strings/strcat_test.cc8
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt393
-rw-r--r--tensorflow/core/ops/dataset_ops.cc12
-rw-r--r--tensorflow/core/ops/list_ops.cc51
-rw-r--r--tensorflow/core/ops/logging_ops.cc7
-rw-r--r--tensorflow/core/ops/ops.pbtxt275
-rw-r--r--tensorflow/core/ops/parsing_ops.cc93
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc82
-rw-r--r--tensorflow/core/ops/sdca_ops.cc2
-rw-r--r--tensorflow/core/platform/default/build_config.bzl6
-rw-r--r--tensorflow/core/protobuf/debug.proto6
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc228
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h3
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing_test.cc2
-rw-r--r--tensorflow/core/util/example_proto_helper.cc53
-rw-r--r--tensorflow/core/util/example_proto_helper.h61
-rw-r--r--tensorflow/core/util/mkl_util.h45
-rw-r--r--tensorflow/core/util/status_util.h36
-rw-r--r--tensorflow/examples/speech_commands/models.py2
-rw-r--r--tensorflow/go/op/wrappers.go238
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/debug/BUILD17
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py9
-rw-r--r--tensorflow/python/debug/lib/debug_utils.py12
-rw-r--r--tensorflow/python/debug/wrappers/disk_usage_test.py109
-rw-r--r--tensorflow/python/debug/wrappers/framework.py25
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py5
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py5
-rw-r--r--tensorflow/python/distribute/BUILD4
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py213
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py150
-rw-r--r--tensorflow/python/distribute/estimator_training.py4
-rw-r--r--tensorflow/python/eager/backprop.py4
-rw-r--r--tensorflow/python/eager/backprop_test.py27
-rw-r--r--tensorflow/python/eager/context.py24
-rw-r--r--tensorflow/python/eager/core_test.py11
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc42
-rwxr-xr-xtensorflow/python/eager/pywrap_tfe.h2
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc7
-rw-r--r--tensorflow/python/eager/tape.py10
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py3
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py13
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py6
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py10
-rw-r--r--tensorflow/python/estimator/estimator.py98
-rw-r--r--tensorflow/python/estimator/estimator_test.py118
-rw-r--r--tensorflow/python/estimator/export/export_output.py24
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py89
-rw-r--r--tensorflow/python/estimator/keras.py75
-rw-r--r--tensorflow/python/estimator/model_fn.py113
-rw-r--r--tensorflow/python/estimator/model_fn_test.py51
-rw-r--r--tensorflow/python/framework/error_interpolation.py81
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py111
-rw-r--r--tensorflow/python/framework/errors_impl.py9
-rw-r--r--tensorflow/python/framework/errors_test.py29
-rw-r--r--tensorflow/python/framework/ops.py21
-rw-r--r--tensorflow/python/framework/test_util.py14
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py4
-rw-r--r--tensorflow/python/keras/initializers.py88
-rw-r--r--tensorflow/python/keras/initializers_test.py15
-rw-r--r--tensorflow/python/keras/layers/recurrent.py7
-rw-r--r--tensorflow/python/keras/metrics.py14
-rw-r--r--tensorflow/python/keras/metrics_test.py29
-rw-r--r--tensorflow/python/keras/models.py16
-rw-r--r--tensorflow/python/keras/optimizers.py10
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py30
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py1158
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py17
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py8
-rw-r--r--tensorflow/python/ops/array_ops.py3
-rw-r--r--tensorflow/python/ops/collective_ops_test.py14
-rw-r--r--tensorflow/python/ops/custom_gradient.py10
-rw-r--r--tensorflow/python/ops/functional_ops.py2
-rw-r--r--tensorflow/python/ops/init_ops.py109
-rw-r--r--tensorflow/python/ops/init_ops_test.py38
-rw-r--r--tensorflow/python/ops/list_ops.py15
-rw-r--r--tensorflow/python/ops/parsing_ops.py346
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py33
-rw-r--r--tensorflow/python/ops/sparse_ops.py6
-rw-r--r--tensorflow/python/ops/variable_scope.py24
-rw-r--r--tensorflow/python/ops/variables.py46
-rwxr-xr-xtensorflow/python/pywrap_tfe.i2
-rw-r--r--tensorflow/python/tools/api/generator/BUILD3
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files.bzl1
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs.py2
-rw-r--r--tensorflow/python/tools/optimize_for_inference_lib.py8
-rw-r--r--tensorflow/python/tools/saved_model_cli.py23
-rw-r--r--tensorflow/python/training/checkpointable/util.py5
-rw-r--r--tensorflow/python/training/distribute.py51
-rw-r--r--tensorflow/python/training/ftrl_test.py101
-rw-r--r--tensorflow/python/training/monitored_session.py6
-rw-r--r--tensorflow/python/training/queue_runner_impl.py22
-rw-r--r--tensorflow/python/training/training_util.py2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc2
-rw-r--r--tensorflow/tensorflow.bzl36
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt18
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt18
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt49
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt49
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt15
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.gpu1
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.gpu6
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu11
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7117
-rw-r--r--tensorflow/tools/docker/Dockerfile.gpu7
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile10
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile10
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile7
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile7
-rw-r--r--tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile10
-rw-r--r--tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile7
-rw-r--r--tensorflow/tools/docs/generate_lib.py4
-rw-r--r--tensorflow/tools/docs/parser.py2
-rw-r--r--tensorflow/tools/docs/pretty_docs.py66
-rwxr-xr-xtensorflow/workspace.bzl60
-rw-r--r--third_party/clang_toolchain/download_clang.bzl8
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL.tpl4
-rw-r--r--third_party/gpus/cuda_configure.bzl13
-rw-r--r--third_party/toolchains/gpus/crosstool/BUILD20
-rw-r--r--third_party/toolchains/gpus/crosstool/CROSSTOOL1123
-rw-r--r--third_party/toolchains/gpus/cuda/BUILD36
-rw-r--r--third_party/toolchains/gpus/cuda/build_defs.bzl8
-rw-r--r--third_party/toolchains/gpus/cuda/cuda/cuda_config.h6
-rw-r--r--third_party/toolchains/gpus/py/BUILD22
-rw-r--r--tools/bazel.rc5
851 files changed, 27997 insertions, 13481 deletions
diff --git a/CODEOWNERS b/CODEOWNERS
index 113eaf798f..1725a5c471 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -54,9 +54,9 @@
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
-/tensorflow/contrib/tensorrt/ @laigd
+/tensorflow/contrib/tensorrt/ @aaroey
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
/tensorflow/contrib/training/ @joel-shor @ebrevdo
-/tensorflow/contrib/util/ @sherrym \ No newline at end of file
+/tensorflow/contrib/util/ @sherrym
diff --git a/README.md b/README.md
index 91f49f8e95..e3092e551e 100644
--- a/README.md
+++ b/README.md
@@ -90,6 +90,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
| **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
+| **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) |
+| **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) |
### Community Supported Builds
diff --git a/configure.py b/configure.py
index 10fee6993e..361bd4764d 100644
--- a/configure.py
+++ b/configure.py
@@ -45,7 +45,7 @@ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
-_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15]
+_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@@ -1543,6 +1543,10 @@ def main():
if environ_cp.get('TF_DOWNLOAD_CLANG') != '1':
# Set up which clang we should use as the cuda / host compiler.
set_clang_cuda_compiler_path(environ_cp)
+ else:
+ # Use downloaded LLD for linking.
+ write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld')
+ write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld')
else:
# Set up which gcc nvcc should use as the host compiler
# No need to set this on Windows
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9cc4c4567b..b5e0a4e98b 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -24,6 +24,10 @@ load(
"gen_api_init_files", # @unused
)
load(
+ "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
+ "TENSORFLOW_API_INIT_FILES_V1", # @unused
+)
+load(
"//third_party/ngraph:build_defs.bzl",
"if_ngraph",
)
@@ -589,6 +593,7 @@ gen_api_init_files(
name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
api_version = 1,
+ output_files = TENSORFLOW_API_INIT_FILES_V1,
root_init_template = "api_template.__init__.py",
)
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 8a9301d584..2c3a877edf 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -127,6 +127,15 @@ tf_cuda_library(
],
)
+cc_library(
+ name = "c_api_headers",
+ hdrs = [
+ "c_api.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
exports_files(
[
"version_script.lds",
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1adb0458c3..ce038a4b57 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
return Status::OK();
}
+gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
+ static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
+ return m;
+}
+
} // namespace
// If over kMinAggregateCount gradients are accumulated and the total
@@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
VLOG(1) << " " << t;
}
}
- gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
- {"SoftmaxCrossEntropyWithLogits", {1}},
- {"FusedBatchNorm", {1, 2, 3, 4}},
- });
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;
@@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
- functions_accept_none_for_indices.find(trace.op_type);
- if (func_name_it != functions_accept_none_for_indices.end() &&
+ FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
+ if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 59b961cdd9..6c29f09cde 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -56,6 +56,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -191,13 +192,13 @@ cc_library(
srcs = ["embedded_protocol_buffers.cc"],
hdrs = ["embedded_protocol_buffers.h"],
deps = [
- "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index e77a8fecf0..2b1ce34b37 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 1401aae758..f1e8e5c084 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -111,7 +111,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
StringPiece target_triple,
- gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
+ absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
GetTargetMachineFromTriple(target_triple));
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index 4e194a6aba..4f940c0197 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -20,8 +20,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -84,7 +84,7 @@ struct ProtobufToEmbed {
// EmbeddedProtocolBuffers instance.
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
StringPiece target_triple,
- gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
+ absl::Span<const ProtobufToEmbed> protobufs_to_embed);
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index fe28502f69..82aa03810b 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -108,7 +108,7 @@ class Predicate {
virtual string ToString() const = 0;
int64 hash() const { return hash_; }
- virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
+ virtual absl::Span<Predicate* const> GetOperands() const = 0;
virtual Kind kind() const = 0;
virtual ~Predicate() {}
@@ -129,7 +129,7 @@ class Predicate {
};
int64 HashPredicateSequence(Predicate::Kind kind,
- gtl::ArraySlice<Predicate*> preds) {
+ absl::Span<Predicate* const> preds) {
int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
for (Predicate* pred : preds) {
hash = Hash64Combine(hash, pred->hash());
@@ -159,8 +159,10 @@ class AndPredicate : public Predicate {
Kind kind() const override { return Kind::kAnd; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
- gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
+ absl::Span<Predicate* const> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -187,8 +189,10 @@ class OrPredicate : public Predicate {
}
Kind kind() const override { return Kind::kOr; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
- gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
+ absl::Span<Predicate* const> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -207,7 +211,9 @@ class NotPredicate : public Predicate {
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operands_[0]; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
private:
std::array<Predicate*, 1> operands_;
@@ -240,7 +246,9 @@ class AndRecurrencePredicate : public Predicate {
Kind kind() const override { return Kind::kAndRecurrence; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
private:
std::array<Predicate*, 2> operands_;
@@ -264,7 +272,7 @@ class SymbolPredicate : public Predicate {
}
Kind kind() const override { return Kind::kSymbol; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
+ absl::Span<Predicate* const> GetOperands() const override { return {}; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
// "tensor_id() is live and evaluates to true".
@@ -313,11 +321,11 @@ template <typename FunctionTy>
// them.
class PredicateFactory {
public:
- Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
+ Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
return MakeAndOrImpl(operands, /*is_and=*/true);
}
- Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
+ Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
return MakeAndOrImpl(operands, /*is_and=*/false);
}
@@ -374,7 +382,7 @@ class PredicateFactory {
new PredicateT(std::forward<Args>(args)...));
}
- Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
+ Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -387,7 +395,7 @@ class PredicateFactory {
// for the owning pointers to predicate instances.
using SignatureForAndOr =
- std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
+ std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
using SignatureForNot = Predicate*;
using SignatureForAndRec = std::pair<Predicate*, Predicate*>;
using SignatureForSymbol = std::pair<SafeTensorId, bool>;
@@ -422,8 +430,8 @@ class PredicateFactory {
};
// Common code to create AndPredicate or OrPredicate instances.
-Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
- bool is_and) {
+Predicate* PredicateFactory::MakeAndOrImpl(
+ absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
gtl::FlatSet<Predicate*> simplified_ops_set;
@@ -474,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
// NB! Because we'll use a non-owning reference to simplified_ops in the
// key for interned_and_or_instances_ we need to be careful to std::move()
// it all the way through.
- gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
std::unique_ptr<Predicate> new_pred =
is_and ? Make<AndPredicate>(std::move(simplified_ops))
: Make<OrPredicate>(std::move(simplified_ops));
@@ -496,7 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
- Status PopulateWithReversePostOrder(gtl::ArraySlice<Node*> rpo);
+ Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
@@ -527,7 +535,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
}
}
- void SetPredicate(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
+ void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
std::vector<bool>* should_revisit) {
for (int output_idx : output_idxs) {
SetPredicate(n, output_idx, pred, should_revisit);
@@ -625,7 +633,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
}
std::vector<Predicate*> and_ops;
- gtl::ArraySlice<Predicate*> recurrent_pred_ops =
+ absl::Span<Predicate* const> recurrent_pred_ops =
backedge_predicate->GetOperands();
bool found_sym = false;
@@ -784,7 +792,7 @@ Status DeadnessAnalysisImpl::Populate() {
}
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
- gtl::ArraySlice<Node*> rpo) {
+ absl::Span<Node* const> rpo) {
// This an abstract interpretation over the deadness propagation semantics of
// the graph executor.
//
@@ -924,7 +932,7 @@ Status ComputePredicates(const Graph& graph,
}
Status ComputePredicates(const Graph& graph,
- gtl::ArraySlice<Node*> reverse_post_order,
+ absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 401d6e406a..3df2679c62 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -32,7 +32,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
// specified in `reverse_post_order` which must be a valid RPO for the graph
// minus NextIteration->Merge edges.
Status ComputePredicates(const Graph& graph,
- gtl::ArraySlice<Node*> reverse_post_order,
+ absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index b3600fc48b..7bc0ef0303 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTestShaped", opts);
}
-Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
+Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
@@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
.FinalizeBuilder(&node_builder);
}
-Node* KnownShape(const gtl::ArraySlice<int>& shape,
+Node* KnownShape(absl::Span<const int> shape,
const GraphDefBuilder::Options& opts) {
return KnownShapeBase(DT_FLOAT, shape, opts);
}
@@ -417,8 +417,7 @@ Node* KeyPlaceholder(const string& call_node,
}
Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
- const string& oc_cluster,
- const gtl::ArraySlice<DataType>& dtypes,
+ const string& oc_cluster, absl::Span<const DataType> dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
string key =
@@ -892,13 +891,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "c:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
@@ -1038,26 +1037,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1190,13 +1189,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1213,13 +1212,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@@ -1364,13 +1363,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1386,13 +1385,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F2_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"i_0_retval", "I:o:0"}});
@@ -1495,13 +1494,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{},
- {{"Tinputs", gtl::ArraySlice<DataType>({})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1579,13 +1578,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{},
- {{"Tinputs", gtl::ArraySlice<DataType>({})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1661,12 +1660,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1742,12 +1741,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1846,13 +1845,13 @@ TEST(EncapsulateSubgraphsTest,
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}}},
},
{{"h_0_retval", "H:o:0"}});
@@ -1955,13 +1954,13 @@ TEST(EncapsulateSubgraphsTest,
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"h_0_retval", "H:o:0"}});
@@ -2066,37 +2065,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute",
- "outside_compilation_O2_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute",
+ "outside_compilation_O2_host_compute"})},
{"key", "host_compute_channel_F1_O3"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
@@ -2272,13 +2271,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"c:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index fde4135bf7..b6f2f632f7 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -57,18 +56,17 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
->stream->parent()
->platform()
->id();
- } else {
- platform_id_ = nullptr;
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
+ use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
+ platform_id_ = xla_device_metadata_->platform()->id();
}
}
Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
- const XlaDevice::Metadata* metadata;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- if (s.ok()) {
- *cache = new XlaCompilationCache(metadata->client(),
- metadata->jit_device_type());
+ if (xla_device_metadata_) {
+ *cache = new XlaCompilationCache(xla_device_metadata_->client(),
+ xla_device_metadata_->jit_device_type());
return Status::OK();
}
@@ -117,18 +115,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
- const XlaDevice::Metadata* metadata = nullptr;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- bool allocate_xla_tensors = s.ok();
- bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
-
- // Get the platform_id_ for XLA_* devices.
- if (platform_id_ == nullptr) {
- if (s.ok()) {
- platform_id_ = metadata->platform()->id();
- }
- }
-
std::map<int, OptionalTensor> variables =
SnapshotResourceVariables(ctx, resources_);
@@ -146,7 +132,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// (which local_xla_allocator above uses) as on an XlaDevice, this is a
// dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
// real allocator to allocate real buffers.
- if (allocate_xla_tensors) {
+ if (xla_device_metadata_) {
xla_allocator = client->backend().memory_allocator();
} else {
xla_allocator = &local_xla_allocator;
@@ -163,8 +149,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
- if (metadata) {
- options.shape_representation_fn = metadata->shape_representation_fn();
+ if (xla_device_metadata_) {
+ options.shape_representation_fn =
+ xla_device_metadata_->shape_representation_fn();
}
const XlaCompiler::CompilationResult* kernel;
@@ -192,7 +179,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
XlaComputationLaunchContext launch_context(
- client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
+ client, xla_allocator,
+ /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
+ use_multiple_streams_);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index bf1e990668..e0f10e9817 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -58,7 +59,9 @@ class XlaLocalLaunchBase : public OpKernel {
DeviceType device_type_;
NameAttrList function_;
- se::Platform::Id platform_id_;
+ se::Platform::Id platform_id_ = nullptr;
+ bool use_multiple_streams_ = false;
+ const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 3a9a8c4988..a8f09bfa50 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -22,7 +22,7 @@ limitations under the License.
namespace tensorflow {
namespace {
Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
- gtl::ArraySlice<Node*> post_order) {
+ absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
// avoid the device-host copy we'd otherwise need.
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 50c902fdfc..f31879a2bc 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -185,14 +185,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return device_type_;
}
-/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
- const Metadata** metadata) {
+/*static*/ Status XlaDevice::GetMetadataFromDevice(
+ DeviceBase* device, const XlaDevice::Metadata** metadata) {
*metadata = nullptr;
- XlaDevice* xla_device =
- dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+ XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
if (xla_device == nullptr) {
return errors::Internal(
- "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
+ "Cannot get XLA metadata from non-XLA device \"", device->name(),
"\". GetMetadata must only be called on an XLA device. Either an "
"internal bug has been triggered, or an XLA-specific op has been "
"placed on the wrong device.");
@@ -201,6 +200,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
+/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
+/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..92891ffa8c 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice {
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);
+ // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
+ static Status GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata);
+
// Factory function. 'platform_name' is the name of the XLA platform.
// 'device_name' is the name of the Tensorflow device to create.
// 'jit_device_name' is the name of the corresponding JIT device.
@@ -158,6 +162,9 @@ class XlaDevice : public LocalDevice {
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ static Status GetMetadataFromDevice(DeviceBase* device,
+ const XlaDevice::Metadata** metadata);
+
mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index cf02926e06..34defe1c7a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -251,6 +251,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",
+ timeout = "moderate",
srcs = ["matrix_triangular_solve_op_test.py"],
tags = ["optonly"],
deps = [
@@ -572,6 +573,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "matrix_band_part_test",
size = "medium",
+ timeout = "long",
srcs = ["matrix_band_part_test.py"],
tags = ["optonly"],
deps = [
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 0d2e4d0296..df0f21471a 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -53,7 +54,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@@ -95,7 +96,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@@ -137,7 +138,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 7ca50b02d9..f1b87a5ffb 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -29,7 +29,6 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
-
class FtrlOptimizerTest(xla_test.XLATestCase):
def initVariableAndGradient(self, dtype):
@@ -196,7 +195,11 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4)
+ np.array([-7.66718769, -10.91273689]),
+ var0.eval(),
+ rtol=1e-4,
+ bfloat16_rtol=1e-1,
+ bfloat16_atol=1e-1)
self.assertAllCloseAccordingToType(
np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4)
@@ -259,9 +262,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4)
+ np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
self.assertAllCloseAccordingToType(
- np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4)
+ np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+ accum0 = list(opt0._slots["accum"].values())[0].eval()
+ accum1 = list(opt1._slots["accum"].values())[0].eval()
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
# When variables are initialized with Zero, FTRL-Proximal has two properties:
# 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 3a268978bf..236b1b881d 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -101,8 +101,8 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
@parameterized.parameters(*PARAMS)
def testQR(self, rows, cols, dtype):
- # TODO(b/111317468): implement full_matrices=False, test other types.
- for full_matrices in [True]:
+ # TODO(b/111317468): Test other types.
+ for full_matrices in [True, False]:
# Only tests the (3, 2) case for small numbers of rows/columns.
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
self._test(dtype, batch_dims + (rows, cols), full_matrices)
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index c0ea242044..0faf0fd8ed 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -275,13 +275,13 @@ class OpTest : public ::testing::Test {
// Select a random element from 'candidates'.
template <typename T>
- T Choose(gtl::ArraySlice<T> candidates);
+ T Choose(absl::Span<const T> candidates);
static constexpr int kDefaultMaxRank = 5;
static constexpr int64 kDefaultMaxDimensionSize = 256LL;
// Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
- bool TensorSizeIsOk(gtl::ArraySlice<int64> dims);
+ bool TensorSizeIsOk(absl::Span<const int64> dims);
// Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
@@ -307,11 +307,11 @@ class OpTest : public ::testing::Test {
// of the type's range. If the shape is omitted, a random shape is used.
// TODO(phawkins): generalize this code to a caller-supplied distribution.
Tensor RandomTensor(DataType dtype, bool needs_unique_values,
- gtl::ArraySlice<int64> shape);
+ absl::Span<const int64> shape);
Tensor RandomTensor(DataType dtype);
// Like RandomTensor, but uses values >= 0.
- Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+ Tensor RandomNonNegativeTensor(DataType dtype, absl::Span<const int64> shape);
Tensor RandomNonNegativeTensor(DataType dtype);
// Returns a random subset of the integers in the range [0, rank), suitable
@@ -415,7 +415,7 @@ void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
}
template <typename T>
-T OpTest::Choose(gtl::ArraySlice<T> candidates) {
+T OpTest::Choose(absl::Span<const T> candidates) {
std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
return candidates[d(generator())];
}
@@ -425,7 +425,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) {
return size_distribution(generator());
}
-bool OpTest::TensorSizeIsOk(gtl::ArraySlice<int64> dims) {
+bool OpTest::TensorSizeIsOk(absl::Span<const int64> dims) {
int64 size = 1LL;
for (int64 dim : dims) {
size *= dim;
@@ -451,7 +451,7 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
}
Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
- gtl::ArraySlice<int64> shape) {
+ absl::Span<const int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
@@ -548,7 +548,7 @@ Tensor OpTest::RandomTensor(DataType dtype) {
}
Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
- gtl::ArraySlice<int64> shape) {
+ absl::Span<const int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
@@ -1884,7 +1884,8 @@ TEST_F(OpTest, DynamicStitch) {
for (int i = 0; i < n; ++i) {
TensorShape shape(index_dims[i]);
Tensor t = test::AsTensor<int32>(
- gtl::ArraySlice<int32>(indices, pos, shape.num_elements()), shape);
+ absl::Span<const int32>(indices).subspan(pos, shape.num_elements()),
+ shape);
builder.Input(t);
pos += t.NumElements();
}
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 5ae5b1bc1d..132c59c32c 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -219,7 +219,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase):
bf16_max = np.float32(dtypes.bfloat16.max)
f32_max = dtypes.float32.max
- value = min(bf16_max, f32_max - bf16_max)
+ value = min(bf16_max, f32_max - bf16_max) / 2
self._testReduceSum(
dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 92e577bb7b..0797b2cb17 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -214,6 +214,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@@ -239,6 +240,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index cc52057f21..c068a4110c 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -805,11 +805,11 @@ TEST(FunctionalizeControlFlow, Complex) {
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
- auto one =
- ops::Const<int32>(scope.WithOpName("outer/inner/One")
- .WithControlDependencies(
- gtl::ArraySlice<Operation>{assign.operation}),
- 1);
+ auto one = ops::Const<int32>(
+ scope.WithOpName("outer/inner/One")
+ .WithControlDependencies(
+ absl::Span<const Operation>{assign.operation}),
+ 1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
@@ -823,7 +823,7 @@ TEST(FunctionalizeControlFlow, Complex) {
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
- .WithControlDependencies(gtl::ArraySlice<Operation>{
+ .WithControlDependencies(absl::Span<const Operation>{
exit_j.output.op(), exit_k.output.op()}),
identity_i, one_outer);
auto next_iteration_i =
@@ -929,7 +929,7 @@ TEST(FunctionalizeControlFlow, Complex) {
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
- .WithControlDependencies(gtl::ArraySlice<Operation>{
+ .WithControlDependencies(absl::Span<const Operation>{
while_op[0].op(), while_op[1].op()}),
identity_i, one_outer);
@@ -991,11 +991,11 @@ TEST(FunctionalizeControlFlow, Complex) {
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
- auto one =
- ops::Const<int32>(scope.WithOpName("outer/inner/One")
- .WithControlDependencies(
- gtl::ArraySlice<Operation>{assign.operation}),
- 1);
+ auto one = ops::Const<int32>(
+ scope.WithOpName("outer/inner/One")
+ .WithControlDependencies(
+ absl::Span<const Operation>{assign.operation}),
+ 1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index c1438f893f..4c776fb178 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -117,6 +117,7 @@ tf_kernel_library(
":while_op",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 48f2a005ab..edced6bc0e 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -23,7 +23,7 @@ namespace {
void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
DataType input_dtype, const TensorShape& input_tensor_shape,
- gtl::ArraySlice<int64> block_shape,
+ absl::Span<const int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
@@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
- gtl::ArraySlice<int64> remainder_shape(input_shape);
+ absl::Span<const int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 2c328102e0..df17da4c1c 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -30,21 +30,21 @@ namespace {
// A subclass of a XlaBinaryOp must build the computation that
// describes the (tensor,tensor)->tensor function to apply to each element of
// the input.
-#define XLA_MAKE_BINARY(NAME, HLO) \
- class NAME##Op : public XlaBinaryOp { \
- public: \
- explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
- xla::XlaOp Computation( \
- XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
- const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs, \
- const gtl::ArraySlice<int64>& rhs_shape, \
- const BCast& broadcast_helper, \
- const std::vector<int64>& extend_dimensions) override { \
- xla::XlaBuilder* b = ctx->builder(); \
- (void)b; \
- return HLO; \
- } \
- }; \
+#define XLA_MAKE_BINARY(NAME, HLO) \
+ class NAME##Op : public XlaBinaryOp { \
+ public: \
+ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
+ xla::XlaOp Computation( \
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
+ const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, \
+ const absl::Span<const int64>& rhs_shape, \
+ const BCast& broadcast_helper, \
+ const std::vector<int64>& extend_dimensions) override { \
+ xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
+ return HLO; \
+ } \
+ }; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op)
XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index a5b870f8db..6653944a91 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel {
// in the XLA documentation.
virtual xla::XlaOp Computation(
XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
- const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs,
- const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+ const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
+ const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
const std::vector<int64>& extend_dimensions) = 0;
void Compile(XlaOpKernelContext* ctx) override;
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 70c3eaf66b..49c12fc232 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -29,7 +29,7 @@ namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
- gtl::ArraySlice<int64> other_dims,
+ absl::Span<const int64> other_dims,
xla::PrimitiveType element_type) {
xla::XlaBuilder* builder = input.builder();
// Create two matrices that have the following forms, and compare them:
@@ -177,7 +177,7 @@ class MatrixDiagOp : public XlaOpKernel {
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
- tensorflow::gtl::ArraySlice<int64> other_dims(dims);
+ absl::Span<const int64> other_dims(dims);
other_dims.remove_suffix(1);
xla::XlaOp input = ctx->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 8e071bf0b7..d9a0257b70 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -78,7 +78,7 @@ struct ResizeConvolutionDims {
std::vector<int64> stride;
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
- gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size,
+ absl::Span<const int64> in_size, absl::Span<const int64> out_size,
bool align_corners) {
CHECK_EQ(in_size.size(), out_size.size());
int num_spatial_dims = in_size.size();
@@ -147,7 +147,7 @@ std::vector<float> Make1DKernel(int64 n) {
const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
- gtl::ArraySlice<int64> kernel_size,
+ absl::Span<const int64> kernel_size,
int64 channels) {
xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
@@ -165,7 +165,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
- gtl::ArraySlice<int64> kernel_size,
+ absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
index de9068a640..7ea0afc1f5 100644
--- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -23,15 +23,10 @@ namespace {
class QROp : public XlaOpKernel {
public:
explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- bool full_matrices;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices));
- OP_REQUIRES(
- ctx, full_matrices,
- errors::Unimplemented("full_matrices=False case of QR decomposition is "
- "not implemented in TF/XLA"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = QRDecomposition(ctx->Input(0));
+ auto result = QRDecomposition(ctx->Input(0), full_matrices_);
if (!result.ok()) {
ctx->SetStatus(result.status());
return;
@@ -39,6 +34,11 @@ class QROp : public XlaOpKernel {
ctx->SetOutput(0, result.ValueOrDie().q);
ctx->SetOutput(1, result.ValueOrDie().r);
}
+
+ private:
+ // If true, compute full-sized q and r. If false, compute only the leading P
+ // columns of q.
+ bool full_matrices_;
};
REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 2da9340625..afd5986846 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel {
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
// Swap the indices at i and swaps[i].
- auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto swap_body_fn = [&](xla::XlaOp i,
+ absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index d9578eca5b..9e4c57c9bf 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -66,7 +66,7 @@ class SelectOp : public XlaOpKernel {
// XLA. It seems we have to broadcast on the left and then Reshape
// to get the dimensions in the right order.
const auto dim_sizes = then_shape.dim_sizes();
- gtl::ArraySlice<int64> bdims = dim_sizes;
+ absl::Span<const int64> bdims = dim_sizes;
bdims.remove_prefix(1);
cond_handle = xla::Broadcast(cond_handle, bdims);
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 6adc3c58de..537b71f3c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Slice Op.
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 7327258c31..b7b4f3a546 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -23,7 +23,7 @@ namespace {
void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
DataType input_dtype, const TensorShape& input_tensor_shape,
- gtl::ArraySlice<int64> block_shape,
+ absl::Span<const int64> block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
const gtl::InlinedVector<int64, 4> input_shape =
@@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
- gtl::ArraySlice<int64> remainder_shape(input_shape);
+ absl::Span<const int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 1062399d91..472d4744d7 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/strided_slice_op.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index be1814d8e3..bb114d1aed 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource,
// relevant slice of 'operand'.
xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
- const gtl::ArraySlice<int64>& update_dims,
+ absl::Span<const int64> update_dims,
const xla::XlaOp& start_indices) {
xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
xla::XlaOp sum = xla::Add(current, update);
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 2c7213f322..93d5996b5e 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Tile Op.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index be5e911386..7077c2e3a5 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
}
// grad_to_use = grad + 2 * l2_shrinkage * var
- // new_accum = accum + grad_to_use * grad_to_use
+ // new_accum = accum + grad * grad
// linear += grad_to_use -
// (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
// quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
@@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
grad_to_use = grad;
}
- xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum = accum + xla::Square(grad);
xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 99511e9914..9365d203f0 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -104,6 +104,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -166,6 +167,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -204,5 +206,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 67fb56510c..c50a8de33e 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -56,14 +56,15 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int n_dims = xla::ShapeUtil::Rank(a_shape);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - 2);
+ auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
xla::XlaOp l = xla::ZerosLike(a);
// Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index b6f30d8d49..0a140fa93c 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -65,9 +65,9 @@ namespace {
// return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice<int64> batch_dims,
- const int64 m, xla::XlaOp* v, xla::XlaOp* tau,
- xla::XlaOp* beta) {
+xla::Status House(xla::XlaOp x, xla::XlaOp k,
+ absl::Span<const int64> batch_dims, const int64 m,
+ xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
xla::XlaBuilder* const builder = x.builder();
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
const xla::PrimitiveType type = x_shape.element_type();
@@ -173,7 +173,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
auto qr_body_fn =
- [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto a = values[0];
auto vs = values[1];
@@ -255,7 +255,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
// There is no need to return Y since at termination of the loop it is equal to
// vs.
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
- xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
+ xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
xla::XlaOp taus, int64 m, int64 n,
xla::PrecisionConfigProto::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size());
@@ -263,7 +263,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
int64 n_index = batch_dims.size() + 1;
auto body_fn =
- [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto w = values[0];
auto y = values[1];
@@ -331,7 +331,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
xla::StatusOr<QRDecompositionResult> QRDecomposition(
- xla::XlaOp a, int64 block_size,
+ xla::XlaOp a, bool full_matrices, int64 block_size,
xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -396,6 +396,13 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
QRDecompositionResult result;
+
+ // full_matrices is false when only a partial result in needed. Slice to the
+ // needed dimensions here.
+ if (!full_matrices) {
+ q = SliceInMinorDims(q, {0, 0}, {m, p});
+ a = SliceInMinorDims(a, {0, 0}, {p, n});
+ }
result.q = q;
result.r = a;
return result;
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index 05565477b6..8a389fb7b0 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -34,7 +34,7 @@ struct QRDecompositionResult {
};
xla::StatusOr<QRDecompositionResult> QRDecomposition(
- xla::XlaOp a, int64 block_size = 128,
+ xla::XlaOp a, bool full_matrices, int64 block_size = 128,
xla::PrecisionConfigProto::Precision precision =
xla::PrecisionConfigProto::HIGHEST);
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index bafe5099f2..38dfde165d 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -40,9 +40,9 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
- gtl::ArraySlice<int64> indices_dims =
+ absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
- gtl::ArraySlice<int64> buffer_dims =
+ absl::Span<const int64> buffer_dims =
xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
@@ -107,7 +107,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 8b5beba383..c267848524 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
return xla::ConstantLiteral(builder, literal);
}
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end) {
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+ absl::Span<const int64> end) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_RET_CHECK(start.size() == end.size());
@@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - n_minor_dims);
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
// Prepends 0s in the major dim
std::vector<int64> padded_start(n_dims, 0);
@@ -143,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
});
}
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
- gtl::ArraySlice<int64> ys) {
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+ absl::Span<const int64> ys) {
std::vector<int64> output(xs.size() + ys.size());
std::copy(xs.begin(), xs.end(), output.begin());
std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
@@ -152,8 +153,8 @@ std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
}
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts,
- gtl::ArraySlice<int64> sizes) {
+ absl::Span<const xla::XlaOp> starts,
+ absl::Span<const int64> sizes) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - sizes.size());
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
auto padded_starts = PrependZerosInMajorDims(x, starts);
auto padded_sizes = ConcatVectors(major_dims, sizes);
return xla::DynamicSlice(x, padded_starts, padded_sizes);
@@ -171,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
}
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start) {
+ absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
@@ -189,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
}
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start) {
+ absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -204,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
}
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<xla::XlaOp> starts) {
+ absl::Span<const xla::XlaOp> starts) {
auto padded_starts = PrependZerosInMajorDims(x, starts);
return xla::DynamicUpdateSlice(x, update, padded_starts);
}
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts) {
+ absl::Span<const xla::XlaOp> starts) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index b4905c9528..80e9e5b002 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Returns a integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
@@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+ absl::Span<const int64> end);
// Returns the concatenation of `xs` and `ys`.
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
- gtl::ArraySlice<int64> ys);
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+ absl::Span<const int64> ys);
// Performs a dynamic slice in the minor dimensions of a Tensor.
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts,
- gtl::ArraySlice<int64> sizes);
+ absl::Span<const xla::XlaOp> starts,
+ absl::Span<const int64> sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start);
+ absl::Span<const int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start);
+ absl::Span<const int64> start);
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index d64394f140..5300e2c878 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, StringPiece name,
xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
@@ -84,15 +84,15 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, StringPiece name,
xla::XlaBuilder* builder) {
auto while_cond_fn =
- [&](gtl::ArraySlice<xla::XlaOp> values,
+ [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
num_iterations));
};
- auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+ auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::XlaOp iteration = values[0];
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 9493b1f109..115ebf390d 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,24 +19,24 @@ limitations under the License.
#include <functional>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
-typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
xla::XlaBuilder*)>
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
- gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+ absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
@@ -50,7 +50,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, StringPiece name,
xla::XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
@@ -59,13 +59,13 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
- xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+ xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
ForEachIndexBodyFunction;
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, StringPiece name,
xla::XlaBuilder* builder);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 77da1bf29c..20103ec3ae 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -49,9 +49,8 @@ Status HostTensorToMutableBorrowingLiteral(
return Status::OK();
}
-Status HostTensorsToBorrowingLiteralTuple(
- tensorflow::gtl::ArraySlice<Tensor> host_tensors,
- xla::BorrowingLiteral* literal) {
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
+ xla::BorrowingLiteral* literal) {
std::vector<const char*> buf_ptrs;
buf_ptrs.reserve(host_tensors.size());
std::vector<xla::Shape> tensor_shapes(host_tensors.size());
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 09d6fa8116..1db7470ee2 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -18,11 +18,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -43,9 +43,8 @@ Status HostTensorToMutableBorrowingLiteral(
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.
-Status HostTensorsToBorrowingLiteralTuple(
- tensorflow::gtl::ArraySlice<Tensor> host_tensors,
- xla::BorrowingLiteral* literal);
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
+ xla::BorrowingLiteral* literal);
// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
// type <target_type>.
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index a3404c2b3d..7dc16b5a46 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -28,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
{
std::vector<int64> int64_values = {1, 2, 3};
std::unique_ptr<xla::Literal> int64_values_literal =
- xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
+ xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@@ -49,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
std::unique_ptr<xla::Literal> int32_values_literal =
- xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
+ xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
.ok());
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index aa2a521d98..0c300c282e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -835,8 +835,8 @@ Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
namespace {
-void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes,
+void SetTransfer(const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes,
tf2xla::HostTransferMetadata* transfer) {
transfer->set_key(key);
CHECK(types.size() == shapes.size());
@@ -850,8 +850,8 @@ void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
} // namespace
Status XlaCompiler::SetDeviceToHostMetadata(
- const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes) {
+ const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes) {
if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
return errors::InvalidArgument(
"Duplicate calls to SetDeviceToHostMetadata with key ", key);
@@ -877,8 +877,8 @@ Status XlaCompiler::GetDeviceToHostShapes(
}
Status XlaCompiler::SetHostToDeviceMetadata(
- const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes) {
+ const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes) {
if (host_compute_recvs_.find(key) != host_compute_sends_.end()) {
return errors::InvalidArgument(
"Duplicate calls to SetHostToDeviceMetadata with key ", key);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 9e2c64fd42..8f4a9858ed 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -351,8 +351,8 @@ class XlaCompiler {
// Sets the shapes and types for the device to host transfer associated with
// 'key'.
Status SetDeviceToHostMetadata(const string& key,
- gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes);
+ absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes);
// Gets the shapes the device to host transfer associated with 'key'.
Status GetDeviceToHostShapes(const string& key,
@@ -361,8 +361,8 @@ class XlaCompiler {
// Sets the shapes and types for the host to device transfer associated with
// 'key'.
Status SetHostToDeviceMetadata(const string& key,
- gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes);
+ absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes);
// In order to avoid deadlocks from dependencies in host computations, it can
// be necessary to enforce a partial order on the execution of HostCompute
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index e36039ada5..24a4b92b45 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 8efb3d55c8..9a34cd8c6a 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
}
/* static */ Status XlaHelpers::ReshapeLiteral(
- const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
+ const xla::Literal& input, absl::Span<const int64> dimensions,
xla::Literal* output) {
if (xla::ShapeUtil::IsTuple(input.shape())) {
return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index e6522157a5..39578144ca 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -18,10 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -50,7 +50,7 @@ class XlaHelpers {
// Reshapes literal 'input' to have 'shape'. Both the original shape and
// 'shape' must contain the same number of elements.
static Status ReshapeLiteral(const xla::Literal& input,
- gtl::ArraySlice<int64> shape,
+ absl::Span<const int64> shape,
xla::Literal* output);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 9e8f5f2a1a..1499c99ed1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -119,7 +119,7 @@ Status XlaOpKernelContext::ConstantInput(StringPiece name,
}
Status XlaOpKernelContext::ConstantInputReshaped(
- int index, gtl::ArraySlice<int64> new_dims,
+ int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) {
const Tensor& tensor = context_->input(index);
TensorShape new_shape(new_dims);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3e26ba4f01..45cfa7da74 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -113,7 +113,7 @@ class XlaOpKernelContext {
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
- Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
+ Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
// Converts a constant scalar int32 or int64 tensor into an int64.
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 2f3a4cd3b5..dae2d956ca 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -105,7 +105,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
/* static */ void XlaOpRegistry::RegisterBackend(
const string& compilation_device_name,
- gtl::ArraySlice<DataType> supported_types, BackendOpFilter op_filter) {
+ absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto result = registry.backends_.emplace(compilation_device_name, Backend());
@@ -382,7 +382,7 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
- gtl::ArraySlice<StringPiece> devices) {
+ absl::Span<const StringPiece> devices) {
registration_->has_device_whitelist = true;
for (StringPiece device : devices) {
registration_->device_whitelist.emplace(device);
@@ -415,7 +415,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
- StringPiece attr_name, gtl::ArraySlice<DataType> allowed) {
+ StringPiece attr_name, absl::Span<const DataType> allowed) {
std::set<DataType>& types =
registration_->type_constraints[string(attr_name)];
for (DataType t : allowed) {
@@ -452,7 +452,7 @@ XlaOpRegistrar::XlaOpRegistrar(
}
XlaBackendRegistrar::XlaBackendRegistrar(
- StringPiece name, gtl::ArraySlice<DataType> types,
+ StringPiece name, absl::Span<const DataType> types,
XlaOpRegistry::BackendOpFilter op_filter) {
XlaOpRegistry& registry = XlaOpRegistry::Instance();
registry.RegisterBackend(string(name), types, op_filter);
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 6ce0e2580b..c640842dc0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -94,7 +94,7 @@ class XlaOpRegistry {
// the device; it may optionally modify the KernelDef.
typedef bool (*BackendOpFilter)(KernelDef* kdef);
static void RegisterBackend(const string& compilation_device_name,
- gtl::ArraySlice<DataType> supported_types,
+ absl::Span<const DataType> supported_types,
BackendOpFilter op_filter);
// Returns the names of the registered backends.
@@ -236,7 +236,7 @@ class XlaOpRegistrationBuilder {
// Specifies a whitelist of devices on which the operator may run.
XlaOpRegistrationBuilder& Device(StringPiece devices);
- XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);
+ XlaOpRegistrationBuilder& Device(absl::Span<const StringPiece> devices);
// Specifies a type constraint for a type variable attribute. Each constraint
// specifies the set of types that the type variable may assume.
@@ -244,7 +244,7 @@ class XlaOpRegistrationBuilder {
DataType allowed);
XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
- gtl::ArraySlice<DataType> allowed);
+ absl::Span<const DataType> allowed);
// Specifies that a dummy copy of this operator should not be registered on
// XLA_* devices, but may be used during compilation.
@@ -288,7 +288,7 @@ class XlaOpRegistrar {
class XlaBackendRegistrar {
public:
- XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
+ XlaBackendRegistrar(StringPiece name, absl::Span<const DataType> types,
XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index ddeba1d91d..76e36f3c46 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -176,6 +176,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -246,6 +247,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -307,6 +309,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -349,6 +352,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -404,6 +408,7 @@ cc_library(
":types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -472,6 +477,7 @@ cc_library(
":types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -483,6 +489,7 @@ tf_cc_test(
":test",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/types:span",
],
)
@@ -510,6 +517,7 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
],
)
@@ -577,6 +585,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -607,6 +616,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -648,6 +658,7 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -672,6 +683,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -702,8 +714,8 @@ cc_library(
":array2d",
":shape_util",
":xla_data_proto",
- "//tensorflow/core:lib",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index c8e483712e..58cc157585 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -29,10 +29,10 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -97,12 +97,11 @@ class Array {
using value_type = T;
// Creates a new array with the specified dimensions.
- explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
- : Array(sizes, T()) {}
+ explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}
// Creates a new array with the specified dimensions and specified value for
// every cell.
- Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
+ Array(absl::Span<const int64> sizes, T value)
: sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
Fill(value);
}
@@ -301,7 +300,7 @@ class Array {
// Invokes a callback with the (indices, value_ptr) for each cell in the
// array.
- void Each(std::function<void(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ void Each(std::function<void(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, &values_[i]);
@@ -309,8 +308,7 @@ class Array {
}
// Invokes a callback with the (indices, value) for each cell in the array.
- void Each(
- std::function<void(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+ void Each(std::function<void(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, values_[i]);
@@ -320,8 +318,7 @@ class Array {
// Invokes a callback with the (indices, value_ptr) for each cell in the
// array. If a callback returns a non-OK status, returns that else returns
// Status::OK().
- Status EachStatus(
- std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, &values_[i]);
@@ -335,8 +332,7 @@ class Array {
// Invokes a callback with the (indices, value) for each cell in the array.
// If a callback returns a non-OK status, returns that else returns
// Status::OK().
- Status EachStatus(
- std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+ Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, values_[i]);
@@ -377,13 +373,13 @@ class Array {
// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
- const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
+ const T& operator()(absl::Span<const int64> indexes) const {
return values_[calculate_index(indexes)];
}
// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
- T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
+ T& operator()(absl::Span<const int64> indexes) {
return values_[calculate_index(indexes)];
}
@@ -438,8 +434,8 @@ class Array {
bool operator!=(const Array<T>& other) const { return !(*this == other); }
// Performs the equivalent of a slice operation on this array.
- Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits) const {
+ Array<T> Slice(absl::Span<const int64> starts,
+ absl::Span<const int64> limits) const {
CHECK_EQ(starts.size(), num_dimensions());
CHECK_EQ(limits.size(), num_dimensions());
@@ -464,7 +460,7 @@ class Array {
// Performs the equivalent of a DynamicUpdateSlice in-place on this array.
void UpdateSlice(const Array<T>& from,
- tensorflow::gtl::ArraySlice<int64> start_indices) {
+ absl::Span<const int64> start_indices) {
CHECK_EQ(from.num_dimensions(), num_dimensions());
std::vector<int64> limit_indices;
std::transform(start_indices.begin(), start_indices.end(),
@@ -484,7 +480,7 @@ class Array {
// Performs an in-place reshape, modifying the dimensions but not the
// underlying data.
- void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
+ void Reshape(absl::Span<const int64> new_dimensions) {
int64 old_num_elements = num_elements();
sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
CHECK_EQ(num_elements(), old_num_elements);
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h
index 8557bb8fe4..e23d317baf 100644
--- a/tensorflow/compiler/xla/array4d.h
+++ b/tensorflow/compiler/xla/array4d.h
@@ -27,10 +27,10 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc
index 927733ea1e..918872a7a0 100644
--- a/tensorflow/compiler/xla/array4d_test.cc
+++ b/tensorflow/compiler/xla/array4d_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <initializer_list>
#include <numeric>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
@@ -27,8 +27,7 @@ namespace {
// Given an Array4D and a 4-tuple index, computes the linear index into the
// array idx represents.
template <typename T>
-int64 Array4DLinearIndex(const Array4D<T>& arr,
- tensorflow::gtl::ArraySlice<int64> idx) {
+int64 Array4DLinearIndex(const Array4D<T>& arr, absl::Span<const int64> idx) {
EXPECT_EQ(4, idx.size());
return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() +
idx[0] * arr.n2() * arr.n3() * arr.n4());
@@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) {
EXPECT_EQ(fullof7.n3(), 4);
EXPECT_EQ(fullof7.n4(), 5);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 7);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });
}
TEST(Array4dTest, ContainerCtor) {
@@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) {
EXPECT_EQ(arr.n3(), 4);
EXPECT_EQ(arr.n4(), 5);
- arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ arr.Each([&arr](absl::Span<const int64> idx, int* cell) {
EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx));
});
}
@@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) {
TEST(Array4dTest, Fill) {
Array4D<int> fullof7(2, 3, 4, 5, 7);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 7);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });
fullof7.Fill(11);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 11);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 11); });
}
TEST(Array4dTest, FillWithMultiples) {
Array4D<float> arr(2, 3, 4, 5);
arr.FillWithMultiples(2.0f);
- arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, float* cell) {
+ arr.Each([&arr](absl::Span<const int64> idx, float* cell) {
EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx));
});
}
diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc
index e8356c9832..2d0ac98bd4 100644
--- a/tensorflow/compiler/xla/array_test.cc
+++ b/tensorflow/compiler/xla/array_test.cc
@@ -163,7 +163,7 @@ TEST(ArrayTest, Each) {
arr.FillWithMultiples(1);
int64 each_count = 0, each_sum = 0;
- arr.Each([&](tensorflow::gtl::ArraySlice<int64> idx, int cell) {
+ arr.Each([&](absl::Span<const int64> idx, int cell) {
int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2];
EXPECT_EQ(lin_idx, cell);
each_count++;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 2638dea1bd..f825f67b44 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -45,6 +45,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -78,6 +79,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -118,9 +120,9 @@ cc_library(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",
"//tensorflow/compiler/xla/service:stream_pool",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
"@llvm//:support",
],
)
@@ -220,6 +222,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 1fdf8f6260..8818f81312 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -163,8 +163,7 @@ Status Client::ResetDevice() {
}
StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
TF_ASSIGN_OR_RETURN(
@@ -212,8 +211,7 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
}
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
ExecuteGraphRequest request;
@@ -252,7 +250,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
}
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
- tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
+ absl::Span<const XlaComputationInstance> computations) {
ExecuteGraphParallelRequest request;
for (const XlaComputationInstance& computation : computations) {
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index be50cebfcc..7960b07868 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -53,7 +53,7 @@ class Client {
// will be filled with profile data from the execution.
StatusOr<std::unique_ptr<GlobalData>> Execute(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
@@ -82,7 +82,7 @@ class Client {
// from each computation.
//
StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
- tensorflow::gtl::ArraySlice<XlaComputationInstance> computations);
+ absl::Span<const XlaComputationInstance> computations);
// Requests device_count device handles available on the target. The returned
// device handles are used to specify the devices to execute the computations
@@ -134,7 +134,7 @@ class Client {
// Execute() and Transfer().
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index 040344c9a6..a6c58cb175 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -23,7 +23,7 @@ namespace xla {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
index d0c83cbfcc..9e3ed23734 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.h
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -52,7 +52,7 @@ class CompileOnlyClient : public Client {
// code. |metadata|, if provided, is populated during compilation.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 888d2f28eb..93334db88b 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -86,7 +86,7 @@ class ExecutableBuildOptions {
void add_disabled_hlo_pass(absl::string_view pass_name) {
disabled_hlo_passes_.push_back(std::string(pass_name));
}
- const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
+ const absl::Span<const std::string> disabled_hlo_passes() const {
return disabled_hlo_passes_;
}
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 8736f18dcf..a18c94c4e6 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -113,7 +113,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h
index c18087ce6b..0ad01728e6 100644
--- a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h
+++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index e569610b85..d3d7edb42a 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -69,8 +69,7 @@ std::array<float, 6> kErfUCoefficient = {
// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients) {
+XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
XlaOp poly = ScalarLike(x, 0.0);
for (float c : coefficients) {
poly = poly * x + ScalarLike(x, c);
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index 13db232556..a6cafd4207 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand);
// Evaluates a polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients);
+XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients);
// Computes an approximation of the error function complement (1 - erf(x)).
XlaOp Erfc(XlaOp x);
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index 02bed80162..377654220b 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -16,60 +16,13 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
-namespace {
-
-template <typename T>
-XlaOp MakeIota(XlaBuilder* builder, int64 size) {
- std::vector<T> values(size);
- for (int64 i = 0; i < size; ++i) {
- values[i] = static_cast<T>(i);
- }
- return ConstantR1<T>(builder, values);
-}
-
-} // namespace
-
-XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
- switch (type) {
- case S8:
- return MakeIota<int8>(builder, size);
- case S16:
- return MakeIota<int16>(builder, size);
- case S32:
- return MakeIota<int32>(builder, size);
- case S64:
- return MakeIota<int64>(builder, size);
- case U8:
- return MakeIota<uint8>(builder, size);
- case U16:
- return MakeIota<uint16>(builder, size);
- case U32:
- return MakeIota<uint32>(builder, size);
- case U64:
- return MakeIota<uint64>(builder, size);
- case BF16:
- return MakeIota<bfloat16>(builder, size);
- case F16:
- return MakeIota<half>(builder, size);
- case F32:
- return MakeIota<float>(builder, size);
- case F64:
- return MakeIota<double>(builder, size);
- case C64:
- return MakeIota<complex64>(builder, size);
- default:
- return builder->ReportError(InvalidArgument(
- "Unimplemented type for Iota: %s.", PrimitiveType_Name(type)));
- }
-}
-
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
int64 n) {
auto a = Iota(builder, type, m);
@@ -86,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ absl::Span<const int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
@@ -113,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ absl::Span<const int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
xla::XlaOp indicator;
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
index 8a96ec68d2..7d6aedd494 100644
--- a/tensorflow/compiler/xla/client/lib/numeric_test.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase {
void TestMatrixDiagonal();
};
-// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to
-// xla::Iota. This test is already implemented for xla::IotaGen in
-// xla/tests/iota_test.cc.
-XLA_TEST_F(NumericTest, Iota) {
- XlaBuilder builder(TestName());
- Iota(&builder, S32, 10);
-
- ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
-}
-
XLA_TEST_F(NumericTest, Triangle) {
XlaBuilder builder(TestName());
Array3D<int32> input(2, 3, 4);
diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc
index 3ae9ae36f6..1979c867a4 100644
--- a/tensorflow/compiler/xla/client/lib/pooling.cc
+++ b/tensorflow/compiler/xla/client/lib/pooling.cc
@@ -26,11 +26,9 @@ namespace {
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
- XlaOp sums, PrimitiveType dtype,
- tensorflow::gtl::ArraySlice<int64> input_shape,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- tensorflow::gtl::ArraySlice<int64> ksize,
- tensorflow::gtl::ArraySlice<int64> stride,
+ XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ absl::Span<const int64> ksize, absl::Span<const int64> stride,
const TensorFormat& data_format) {
// The padding shouldn't be included in the counts. We use another
// ReduceWindow to find the right counts.
@@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding(
// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -89,8 +87,8 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- int num_spatial_dims, tensorflow::gtl::ArraySlice<int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ int num_spatial_dims, absl::Span<const int64> stride,
const TensorFormat& data_format) {
PaddingConfig padding_config;
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
@@ -107,13 +105,12 @@ PaddingConfig MakeSpatialPaddingConfig(
return padding_config;
}
-XlaOp AvgPoolDivideByCount(
- XlaOp pooled, tensorflow::gtl::ArraySlice<int64> input_size,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- PrimitiveType dtype, const TensorFormat& data_format,
- bool counts_include_padding) {
+XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ PrimitiveType dtype, const TensorFormat& data_format,
+ bool counts_include_padding) {
if (counts_include_padding) {
// If counts include padding, all windows have the same number of elements
// contributing to each average. Divide by the window size everywhere to get
@@ -133,8 +130,8 @@ XlaOp AvgPoolDivideByCount(
} // namespace
-XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -147,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
});
}
-XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = operand.builder();
@@ -173,9 +170,8 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
}
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
- tensorflow::gtl::ArraySlice<int64> input_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
std::vector<int64> input_spatial_dimensions;
@@ -193,12 +189,12 @@ std::vector<std::pair<int64, int64>> MakeSpatialPadding(
stride_spatial_dimensions, padding);
}
-XlaOp AvgPoolGrad(
- XlaOp out_backprop, tensorflow::gtl::ArraySlice<int64> gradients_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- const TensorFormat& data_format, const bool counts_include_padding) {
+XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ const TensorFormat& data_format,
+ const bool counts_include_padding) {
XlaBuilder* b = out_backprop.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
const int num_dims = kernel_size.size();
diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h
index 291c711a00..5c0054857d 100644
--- a/tensorflow/compiler/xla/client/lib/pooling.h
+++ b/tensorflow/compiler/xla/client/lib/pooling.h
@@ -25,7 +25,7 @@ namespace xla {
class TensorFormat {
public:
TensorFormat(int batch_dimension, int feature_dimension,
- tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
+ absl::Span<const int64> spatial_dimensions)
: batch_dimension_(batch_dimension),
feature_dimension_(feature_dimension),
spatial_dimensions_(spatial_dimensions.begin(),
@@ -49,32 +49,31 @@ class TensorFormat {
};
// Computes the max pool of 'operand'.
-XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);
// Computes the average pool of 'operand'.
-XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding);
// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
- tensorflow::gtl::ArraySlice<int64> input_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);
// Computes the average pool gradient.
-XlaOp AvgPoolGrad(
- XlaOp out_backprop, tensorflow::gtl::ArraySlice<int64> gradients_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- const TensorFormat& data_format, const bool counts_include_padding);
+XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ const TensorFormat& data_format,
+ const bool counts_include_padding);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc
index 1890047918..30adb9b1ad 100644
--- a/tensorflow/compiler/xla/client/lib/pooling_test.cc
+++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc
@@ -32,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) {
}
std::vector<std::pair<int64, int64>> MakeGeneralPadding(
- XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ XlaOp input, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const xla::TensorFormat& data_format) {
XlaBuilder* b = input.builder();
Shape operand_shape = b->GetShape(input).ValueOrDie();
@@ -46,7 +46,7 @@ std::vector<std::pair<int64, int64>> MakeGeneralPadding(
// Add singleton batch and feature dimensions to spatial dimensions, according
// to 'data_format' specification.
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
- tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
+ absl::Span<const int64> spatial_dim_sizes,
const xla::TensorFormat& data_format) {
const int num_spatial_dims = spatial_dim_sizes.size();
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index db7a8fc047..4402ba8762 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
}
Status LocalExecutable::ValidateExecutionOptions(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
const ComputationLayout& computation_layout =
executable_->module_config().entry_computation_layout();
@@ -140,7 +140,7 @@ Status LocalExecutable::ValidateExecutionOptions(
}
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options) {
TF_RETURN_IF_ERROR(
ValidateExecutionOptions(arguments, run_options, *backend_));
@@ -177,7 +177,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ const absl::Span<const ShapedBuffer* const> arguments) {
executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
@@ -191,7 +191,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
}
Status LocalExecutable::RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
@@ -245,7 +245,7 @@ Backend* LocalClient::mutable_backend() {
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options) {
ExecutableBuildOptions updated_options = options;
if (options.device_ordinal() == -1) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index ae23809261..56c3a3da02 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -40,7 +40,7 @@ class LocalExecutable {
// Run the compiled computation with the given arguments and options and
// return the result.
StatusOr<ScopedShapedBuffer> Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options);
// Return the options used to build the executable.
@@ -63,7 +63,7 @@ class LocalExecutable {
// The given ExecutableRunOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
Status ValidateExecutionOptions(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);
// Records the computation in a SessionModule proto with the arguments used to
@@ -73,13 +73,12 @@ class LocalExecutable {
// (TF_XLA_FLAGS environment variable).
StatusOr<ScopedShapedBuffer> ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ const absl::Span<const ShapedBuffer* const> arguments);
// Records the arguments used to invoke the computation in a SessionModule
// proto.
- Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- HloSnapshot* hlo_snapshot);
+ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
+ HloSnapshot* hlo_snapshot);
// Records the result of the computation in a SessionModule proto.
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
@@ -120,7 +119,7 @@ class LocalClient : public Client {
// (TF_XLA_FLAGS environment variable).
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options);
// Copy the literal data to the device with the given ordinal and return as a
diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc
index ed4dc8e9f6..992b13139c 100644
--- a/tensorflow/compiler/xla/client/padding.cc
+++ b/tensorflow/compiler/xla/client/padding.cc
@@ -23,10 +23,9 @@ limitations under the License.
namespace xla {
-Status ValidatePaddingValues(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides) {
+Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides) {
bool ok = input_dimensions.size() == window_dimensions.size() &&
input_dimensions.size() == window_strides.size();
if (!ok) {
@@ -40,9 +39,9 @@ Status ValidatePaddingValues(
}
std::vector<std::pair<int64, int64>> MakePadding(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding) {
TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions,
window_strides));
std::vector<std::pair<int64, int64>> low_high_padding;
diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h
index e23b0b3a90..5c009bd49e 100644
--- a/tensorflow/compiler/xla/client/padding.h
+++ b/tensorflow/compiler/xla/client/padding.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -41,10 +41,9 @@ enum class Padding {
// Validates that the slices are acceptable for determining padding -- this can
// be used to check the preconditions of MakePadding below to produce an error
// message that can be returned to the user.
-Status ValidatePaddingValues(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides);
+Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides);
// Returns the padding needed for the base area, given the base area dimensions,
// window dimensions, strides, and the type of padding.
@@ -58,9 +57,9 @@ Status ValidatePaddingValues(
// window_dimensions, and strides must match, which is equal to the number
// of elements in the result vector.
std::vector<std::pair<int64, int64>> MakePadding(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 819d324927..e639028ccd 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -90,7 +90,7 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
}
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) const {
+ absl::Span<const XlaOp> operands) const {
std::vector<Shape> operand_shapes;
for (const XlaOp& operand : operands) {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
@@ -291,7 +291,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
const Shape& shape, const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(first_error_);
HloInstructionProto instr;
@@ -352,9 +352,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
});
}
-XlaOp XlaBuilder::BinaryOp(
- HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -448,12 +447,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
@@ -466,7 +465,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
});
}
-XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) {
+XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape;
@@ -475,12 +474,12 @@ XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) {
});
}
-XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) {
- return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
+XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
+ return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
XlaOp XlaBuilder::Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ absl::Span<const XlaOp> operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@@ -515,8 +514,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
});
}
-XlaOp XlaBuilder::Broadcast(
- const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
+ absl::Span<const int64> broadcast_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -541,7 +540,7 @@ XlaOp XlaBuilder::Broadcast(
XlaOp XlaBuilder::BroadcastInDim(
const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ const absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return InDimBroadcast(shape, operand, broadcast_dimensions);
});
@@ -556,9 +555,9 @@ StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
}
XlaOp XlaBuilder::Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -593,7 +592,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
}
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -631,7 +630,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
});
}
-XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -671,8 +670,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
@@ -686,7 +685,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
@@ -696,7 +695,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
// Not collapsing anything, trivially we can return the operand versus
@@ -706,8 +705,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
// Out-of-order collapse is not supported.
// Checks that the collapsed dimensions are in order and consecutive.
- for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
- i < dimensions.size(); ++i) {
+ for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
if (dimensions[i] - 1 != dimensions[i - 1]) {
return InvalidArgument(
"Collapsed dimensions are not in consecutive order.");
@@ -758,7 +756,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
});
}
-XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
+XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@@ -792,32 +790,32 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
}
XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions);
}
@@ -899,8 +897,8 @@ Status XlaBuilder::VerifyConvolution(
}
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count,
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
@@ -909,9 +907,8 @@ XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ConvGeneral(lhs, rhs, window_strides, padding,
@@ -920,9 +917,8 @@ XlaOp XlaBuilder::ConvWithGeneralPadding(
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -957,9 +953,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
}
XlaOp XlaBuilder::ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
@@ -969,11 +964,9 @@ XlaOp XlaBuilder::ConvGeneral(
}
XlaOp XlaBuilder::ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
@@ -1013,11 +1006,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
}
StatusOr<Window> XlaBuilder::MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation) const {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation) const {
const auto verify_size = [&](const size_t x, const char* x_name) {
if (x == 0 || x == window_dimensions.size()) {
return Status::OK();
@@ -1067,7 +1060,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
}
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1276,7 +1269,7 @@ XlaOp XlaBuilder::CreateToken() {
});
}
-XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (tokens.empty()) {
return InvalidArgument("AfterAll requires at least one operand");
@@ -1288,7 +1281,7 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+ absl::Span<const XlaOp> operands,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1304,9 +1297,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
});
}
-XlaOp XlaBuilder::Complex(
- const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions);
}
@@ -1315,42 +1307,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) {
}
XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}
@@ -1358,22 +1350,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNot, operand);
}
-XlaOp XlaBuilder::ShiftLeft(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
broadcast_dimensions);
}
XlaOp XlaBuilder::ShiftRightLogical(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
broadcast_dimensions);
}
@@ -1382,9 +1373,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) {
return UnaryOp(HloOpcode::kAbs, operand);
}
-XlaOp XlaBuilder::Atan2(
- const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions);
}
@@ -1449,7 +1439,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
}
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+ absl::Span<const int64> permutation) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1464,7 +1454,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
}
XlaOp XlaBuilder::Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1506,7 +1496,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional<XlaOp> values,
}
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
}
@@ -1544,10 +1534,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
return TernaryOp(HloOpcode::kClamp, min, operand, max);
}
-XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!static_operands.empty()) {
return Unimplemented("static_operands is not supported in Map");
@@ -1588,7 +1578,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
}
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ absl::Span<const XlaOp> parameters,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1649,7 +1639,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1729,22 +1719,39 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
});
}
-XlaOp XlaBuilder::Reduce(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
+ return Reduce(absl::Span<const XlaOp>({operand}),
+ absl::Span<const XlaOp>({init_value}), computation,
+ dimensions_to_reduce);
+}
+
+XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
- TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
- TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
- ShapeInference::InferReduceShape(
- {&operand_shape, &init_shape}, dimensions_to_reduce,
- called_program_shape));
+ std::vector<XlaOp> all_operands;
+ all_operands.insert(all_operands.end(), operands.begin(), operands.end());
+ all_operands.insert(all_operands.end(), init_values.begin(),
+ init_values.end());
+
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
+ GetOperandShapes(all_operands));
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferReduceShape(
+ operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
for (int64 dim : dimensions_to_reduce) {
instr.add_dimensions(dim);
@@ -1752,8 +1759,7 @@ XlaOp XlaBuilder::Reduce(
AddCalledComputation(computation, &instr);
- return AddInstruction(std::move(instr), HloOpcode::kReduce,
- {operand, init_value});
+ return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
});
}
@@ -1767,11 +1773,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
});
}
-XlaOp XlaBuilder::ReduceWindow(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1792,9 +1798,9 @@ XlaOp XlaBuilder::ReduceWindow(
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1889,8 +1895,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
}
XlaOp XlaBuilder::CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) {
+ const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
@@ -1905,7 +1910,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups,
+ absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -2005,12 +2010,13 @@ XlaOp XlaBuilder::CollectivePermute(
});
}
-XlaOp XlaBuilder::SelectAndScatter(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand,
+ const XlaComputation& select,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
return SelectAndScatterWithGeneralPadding(
@@ -2023,11 +2029,10 @@ XlaOp XlaBuilder::SelectAndScatter(
XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -2410,9 +2415,9 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
return Status::OK();
}
-StatusOr<XlaOp> XlaBuilder::AddInstruction(
- HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
+ HloOpcode opcode,
+ absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
const int64 handle = instructions_.size();
@@ -2486,14 +2491,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
return builder->ConstantLiteral(literal);
}
-XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes) {
return operand.builder()->Broadcast(operand, broadcast_sizes);
}
-XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions) {
return operand.builder()->BroadcastInDim(operand, shape,
broadcast_dimensions);
}
@@ -2503,26 +2506,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
return operand.builder()->Pad(operand, padding_value, padding_config);
}
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return operand.builder()->Reshape(operand, dimensions, new_sizes);
}
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes) {
return operand.builder()->Reshape(operand, new_sizes);
}
-XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions) {
return operand.builder()->Collapse(operand, dimensions);
}
-XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return operand.builder()->Slice(operand, start_indices, limit_indices,
strides);
}
@@ -2534,7 +2533,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
}
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}
@@ -2543,8 +2542,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}
-XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
int64 dimension) {
return builder->ConcatInDim(operands, dimension);
}
@@ -2557,7 +2555,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
return pred.builder()->Select(pred, on_true, on_false);
}
-XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements) {
+XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
return builder->Tuple(elements);
}
@@ -2566,32 +2564,32 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
}
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
}
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
}
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
}
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
}
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
}
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
}
@@ -2608,7 +2606,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
@@ -2616,9 +2614,8 @@ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
@@ -2627,9 +2624,8 @@ XlaOp ConvWithGeneralPadding(
}
XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvWithGeneralDimensions(
@@ -2638,8 +2634,8 @@ XlaOp ConvWithGeneralDimensions(
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
@@ -2648,22 +2644,21 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
precision_config_proto);
}
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto) {
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, precision_config_proto);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length) {
+ absl::Span<const int64> fft_length) {
return operand.builder()->Fft(operand, fft_type, fft_length);
}
@@ -2677,99 +2672,106 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
}
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ absl::Span<const XlaOp> operands) {
return builder->Call(computation, operands);
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape) {
+ absl::Span<const XlaOp> operands, const Shape& shape) {
return builder->CustomCall(call_target_name, operands, shape);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return real.builder()->Complex(real, imag, broadcast_dimensions);
}
XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
}
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
}
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
}
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
}
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
}
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
}
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
}
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
}
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
}
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
}
XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
}
-XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
}
-XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
}
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ absl::Span<const int64> dimensions_to_reduce) {
return operand.builder()->Reduce(operand, init_value, computation,
dimensions_to_reduce);
}
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
+ return builder->Reduce(operands, init_values, computation,
+ dimensions_to_reduce);
+}
+
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
return operand.builder()->ReduceAll(operand, init_value, computation);
@@ -2777,9 +2779,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding) {
return operand.builder()->ReduceWindow(operand, init_value, computation,
window_dimensions, window_strides,
padding);
@@ -2788,22 +2789,21 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
padding);
}
-XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) {
+XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups) {
return operand.builder()->CrossReplicaSum(operand, replica_groups);
}
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups,
+ absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
return operand.builder()->CrossReplicaSum(operand, computation,
replica_groups, channel_id);
@@ -2823,10 +2823,10 @@ XlaOp CollectivePermute(
}
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value, const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter) {
return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
window_strides, padding, source,
init_value, scatter);
@@ -2834,11 +2834,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter) {
return operand.builder()->SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source,
init_value, scatter);
@@ -2847,7 +2846,7 @@ XlaOp SelectAndScatterWithGeneralPadding(
XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return y.builder()->Atan2(y, x, broadcast_dimensions);
}
@@ -2880,7 +2879,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
}
@@ -2898,12 +2897,11 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
-XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
return operand.builder()->Transpose(operand, permutation);
}
-XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
@@ -2915,10 +2913,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
return min.builder()->Clamp(min, operand, max);
}
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ const XlaComputation& computation, absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands) {
return builder->Map(operands, computation, dimensions, static_operands);
}
@@ -2952,7 +2949,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return input.builder()->Gather(input, start_indices, dimension_numbers,
slice_sizes);
}
@@ -3008,7 +3005,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
-XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
return builder->AfterAll(tokens);
}
@@ -3035,12 +3032,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
grad_output, epsilon, feature_index);
}
-XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
- return builder->IotaGen(type, size);
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ return builder->Iota(type, size);
}
-XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
- return builder->IotaGen(shape, iota_dimension);
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
+ return builder->Iota(shape, iota_dimension);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 193d8ed071..59fbc664f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -294,7 +294,7 @@ class XlaBuilder {
template <typename NativeT>
XlaOp ConstantR0(NativeT value);
template <typename NativeT>
- XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ XlaOp ConstantR1(absl::Span<const NativeT> values);
XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(
@@ -336,7 +336,7 @@ class XlaBuilder {
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
// Performs in-dimension-style broadcast.
//
@@ -355,9 +355,8 @@ class XlaBuilder {
// will generate output
// [1 , 1]
// [2 , 2]
- XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions);
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
@@ -370,15 +369,13 @@ class XlaBuilder {
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
@@ -398,8 +395,7 @@ class XlaBuilder {
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
- XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
@@ -412,10 +408,9 @@ class XlaBuilder {
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice
- XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
@@ -436,7 +431,7 @@ class XlaBuilder {
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
@@ -459,8 +454,7 @@ class XlaBuilder {
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
- XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
+ XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
// Enqueue a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
@@ -471,34 +465,34 @@ class XlaBuilder {
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
- XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements);
+ XlaOp Tuple(absl::Span<const XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
@@ -513,7 +507,7 @@ class XlaBuilder {
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -521,8 +515,8 @@ class XlaBuilder {
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -530,7 +524,7 @@ class XlaBuilder {
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -539,8 +533,8 @@ class XlaBuilder {
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -549,10 +543,10 @@ class XlaBuilder {
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -560,7 +554,7 @@ class XlaBuilder {
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
@@ -582,15 +576,14 @@ class XlaBuilder {
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
// During code generation, a call instruction is emitted which targets a
// symbol with the name |call_target_name|. The |operands| are passed to the
// call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -599,65 +592,70 @@ class XlaBuilder {
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Not(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
+ XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
+ XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+
+ // Reduces several arrays simultaneously among the provided dimensions, given
+ // "computation" as a reduction operator.
+ XlaOp Reduce(absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
@@ -667,25 +665,23 @@ class XlaBuilder {
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
- XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {});
+ XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups = {});
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
@@ -707,7 +703,7 @@ class XlaBuilder {
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {},
+ absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
@@ -724,8 +720,8 @@ class XlaBuilder {
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding, const XlaOp& source,
const XlaOp& init_value,
const XlaComputation& scatter);
@@ -734,18 +730,17 @@ class XlaBuilder {
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
// Enqueues a atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
@@ -792,7 +787,7 @@ class XlaBuilder {
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
@@ -801,10 +796,10 @@ class XlaBuilder {
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.
- XlaOp IotaGen(const Shape& shape, int64 iota_dimension);
+ XlaOp Iota(const Shape& shape, int64 iota_dimension);
// Enqueues a rank-1 iota operation onto the computation.
- XlaOp IotaGen(PrimitiveType type, int64 size);
+ XlaOp Iota(PrimitiveType type, int64 size);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
@@ -822,14 +817,12 @@ class XlaBuilder {
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
- XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
- XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
@@ -854,10 +847,9 @@ class XlaBuilder {
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
- XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+ XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands = {});
// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
@@ -884,7 +876,7 @@ class XlaBuilder {
// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -912,7 +904,7 @@ class XlaBuilder {
// Enqueues an AfterAll operation with no operands producing a token-shaped
// value.
- XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
+ XlaOp AfterAll(absl::Span<const XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
@@ -959,9 +951,8 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- StatusOr<XlaOp> AddInstruction(
- HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands = {});
+ StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
+ absl::Span<const XlaOp> operands = {});
void AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr);
@@ -975,19 +966,17 @@ class XlaBuilder {
// broadcast_dimensions specifies which dimensions to use for broadcasting
// when the operation is between tensors of different ranks.
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Internal helper method that does the building for an arbitrary ternary op.
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
const XlaOp& ehs);
XlaOp RngOp(RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<XlaOp> parameters,
- const Shape& shape);
+ absl::Span<const XlaOp> parameters, const Shape& shape);
- StatusOr<XlaOp> InDimBroadcast(
- const Shape& shape, const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
+ absl::Span<const int64> broadcast_dimensions);
// Internal helper method that creates a sequence of instructions that
// performs an explicit broadcast of the operand to the target shape.
@@ -1003,7 +992,7 @@ class XlaBuilder {
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) const;
+ absl::Span<const XlaOp> operands) const;
// A visitor which checks whether an operation is a compile-time constant,
// meaning that it doesn't depend on any parameters, or on any stateful
@@ -1020,12 +1009,11 @@ class XlaBuilder {
// Helper function for creating a Window proto from user-supplied data.
// Returns error if the user-supplied data was invalid.
- StatusOr<Window> MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation) const;
+ StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation) const;
string name_; // Name to use for the built computation.
@@ -1069,7 +1057,7 @@ class XlaBuilder {
friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
friend XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
+ absl::Span<const NativeT> values);
friend XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values);
template <typename NativeT>
@@ -1109,188 +1097,183 @@ class XlaBuilder {
friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
friend XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
friend XlaOp BroadcastInDim(
const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ const absl::Span<const int64> broadcast_dimensions);
friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config);
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
friend XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
friend XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);
friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices);
friend XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
+ absl::Span<const XlaOp> operands, int64 dimension);
friend void Trace(const string& tag, const XlaOp& operand);
friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
const XlaOp& on_false);
- friend XlaOp Tuple(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfigProto* precision_config_proto);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_number,
const PrecisionConfigProto* precision_config_proto);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count,
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count,
- const PrecisionConfigProto* precision_config_proto);
- friend XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto);
+ friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count,
+ const PrecisionConfigProto* precision_config_proto);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config);
friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Not(const XlaOp& operand);
- friend XlaOp ShiftLeft(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+ friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation);
- friend XlaOp ReduceWindow(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding);
friend XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups,
- const absl::optional<ChannelHandle>& channel_id);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups);
+ friend XlaOp CrossReplicaSum(const XlaOp& operand,
+ const XlaComputation& computation,
+ absl::Span<const ReplicaGroup> replica_groups,
+ const absl::optional<ChannelHandle>& channel_id);
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
friend XlaOp CollectivePermute(
const XlaOp& operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs);
- friend XlaOp SelectAndScatter(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ friend XlaOp SelectAndScatter(const XlaOp& operand,
+ const XlaComputation& select,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter);
friend XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
friend XlaOp Abs(const XlaOp& operand);
friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Exp(const XlaOp& operand);
friend XlaOp Expm1(const XlaOp& operand);
friend XlaOp Floor(const XlaOp& operand);
@@ -1306,29 +1289,25 @@ class XlaBuilder {
friend XlaOp Real(const XlaOp& operand);
friend XlaOp Imag(const XlaOp& operand);
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
- // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
- // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
- friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape,
- int64 iota_dimension);
- friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+ friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
+ int64 iota_dimension);
+ friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
friend XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp Neg(const XlaOp& operand);
friend XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
- friend XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
friend XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values, int64 dimension);
friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
- friend XlaOp Map(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+ friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands);
friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
const Shape& shape);
friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
@@ -1342,7 +1321,7 @@ class XlaBuilder {
const int mantissa_bits);
friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
@@ -1376,8 +1355,7 @@ class XlaBuilder {
const Shape& shape_with_layout,
const string& outfeed_config);
friend XlaOp CreateToken(XlaBuilder* builder);
- friend XlaOp AfterAll(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> tokens);
+ friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
};
// RAII-style object: sets the current sharding assignment in builder on
@@ -1441,8 +1419,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
+XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
@@ -1491,8 +1468,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
// The new dimensions index into copies of the operand, i.e.
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
-XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);
// Performs in-dimension-style broadcast.
//
@@ -1511,9 +1487,8 @@ XlaOp Broadcast(const XlaOp& operand,
// will generate output
// [1 , 1]
// [2 , 2]
-XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions);
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
@@ -1526,15 +1501,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
@@ -1554,8 +1527,7 @@ XlaOp Reshape(const XlaOp& operand,
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
-XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
@@ -1568,10 +1540,9 @@ XlaOp Collapse(const XlaOp& operand,
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice
-XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
@@ -1592,7 +1563,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
@@ -1615,8 +1586,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
-XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
+XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ int64 dimension);
// Enqueue a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
@@ -1627,34 +1598,34 @@ void Trace(const string& tag, const XlaOp& operand);
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
-XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
+XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
@@ -1668,33 +1639,31 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -1702,11 +1671,9 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
const PrecisionConfigProto* precision_config_proto = nullptr);
@@ -1714,7 +1681,7 @@ XlaOp ConvGeneralDilated(
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
@@ -1746,15 +1713,14 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
// During code generation, a call instruction is emitted which targets a
// symbol with the name |call_target_name|. The |operands| are passed to the
// call instruction. |shape| is the resultant shape.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1763,65 +1729,70 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Not(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
+XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
+XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
@@ -1831,25 +1802,23 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
-XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {});
+XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups = {});
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
@@ -1870,7 +1839,7 @@ XlaOp CrossReplicaSum(
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {},
+ absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
@@ -1893,27 +1862,26 @@ XlaOp CollectivePermute(
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value, const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
// Enqueues a atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
@@ -1960,7 +1928,7 @@ XlaOp Imag(const XlaOp& operand);
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
@@ -1969,10 +1937,10 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.
-XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
// Enqueues a rank-1 iota operation onto the computation.
-XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
@@ -1988,13 +1956,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
-XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
-XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
@@ -2019,10 +1986,9 @@ XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values = absl::nullopt,
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ const XlaComputation& computation, absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands = {});
// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
@@ -2049,7 +2015,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -2107,7 +2073,7 @@ XlaOp CreateToken(XlaBuilder* builder);
// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
-XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
// Normalizes operand across spatial and batch dimensions for each feature.
//
@@ -2155,7 +2121,7 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) {
}
template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
}
@@ -2232,8 +2198,7 @@ XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
}
template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values) {
+XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
}
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
index 693dcb3a3e..3fadabcf52 100644
--- a/tensorflow/compiler/xla/index_util.cc
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -27,7 +27,7 @@ limitations under the License.
namespace xla {
/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ const Shape& shape, absl::Span<const int64> multi_index) {
DCHECK_EQ(shape.dimensions_size(), multi_index.size());
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
@@ -118,8 +118,8 @@ namespace xla {
return multi_index;
}
-/* static */ bool IndexUtil::BumpIndices(
- const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices) {
+/* static */ bool IndexUtil::BumpIndices(const Shape& shape,
+ absl::Span<int64> indices) {
for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) {
int64 limit = shape.dimensions(dimno);
if (indices[dimno] + 1 < limit) {
@@ -149,8 +149,8 @@ namespace xla {
return stride;
}
-/* static */ bool IndexUtil::IndexInBounds(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
+/* static */ bool IndexUtil::IndexInBounds(const Shape& shape,
+ absl::Span<const int64> index) {
int64 rank = ShapeUtil::Rank(shape);
if (rank != index.size()) {
return false;
@@ -163,9 +163,8 @@ namespace xla {
return true;
}
-/* static */ int IndexUtil::CompareIndices(
- tensorflow::gtl::ArraySlice<int64> lhs,
- tensorflow::gtl::ArraySlice<int64> rhs) {
+/* static */ int IndexUtil::CompareIndices(absl::Span<const int64> lhs,
+ absl::Span<const int64> rhs) {
int64 rank = lhs.size();
CHECK_EQ(rhs.size(), rank);
for (int64 dim = 0; dim < rank; ++dim) {
diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h
index 142006f262..2979cf87dd 100644
--- a/tensorflow/compiler/xla/index_util.h
+++ b/tensorflow/compiler/xla/index_util.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -35,7 +35,7 @@ class IndexUtil {
// on the shape and its layout. The first index in the multi_index is
// dimension 0.
static int64 MultidimensionalIndexToLinearIndex(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index);
+ const Shape& shape, absl::Span<const int64> multi_index);
// Converts a linear index into multidimensional index (eg {x, y, z}) based on
// the shape and its layout. The first index in the returned multidimensional
@@ -58,8 +58,7 @@ class IndexUtil {
//
// Returns true iff the indices were successfully bumped; false if we've hit
// the limit where it can no longer be bumped in-bounds.
- static bool BumpIndices(const Shape& shape,
- tensorflow::gtl::MutableArraySlice<int64> indices);
+ static bool BumpIndices(const Shape& shape, absl::Span<int64> indices);
// Calculates the stride size (in number of elements, not byte size) of a
// given logical shape dimension (from 0 to rank-1). If available, padded
@@ -71,15 +70,14 @@ class IndexUtil {
// Returns true iff the given multi-index is contained in the bounds for the
// shape.
- static bool IndexInBounds(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> index);
+ static bool IndexInBounds(const Shape& shape, absl::Span<const int64> index);
// Compares the given indices in lexicographic order. lhs[0] and rhs[0] are
// compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger,
// then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is
// returned.
- static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,
- tensorflow::gtl::ArraySlice<int64> rhs);
+ static int CompareIndices(absl::Span<const int64> lhs,
+ absl::Span<const int64> rhs);
private:
TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc
index 7c4efdee48..93522d2ca8 100644
--- a/tensorflow/compiler/xla/index_util_test.cc
+++ b/tensorflow/compiler/xla/index_util_test.cc
@@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) {
TEST(IndexUtilTest, BumpIndices2x2) {
auto shape = ShapeUtil::MakeShape(S32, {2, 2});
std::vector<int64> indices = {0, 0};
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(0, 1));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 0));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 1));
- EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
}
} // namespace
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index cce1838ef3..d310335618 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer(
} // namespace
/* static */ Layout LayoutUtil::MakeLayout(
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ absl::Span<const int64> minor_to_major) {
Layout layout;
layout.set_format(DENSE);
for (int64 dimension_number : minor_to_major) {
@@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer(
}
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
- tensorflow::gtl::ArraySlice<int64> major_to_minor) {
+ absl::Span<const int64> major_to_minor) {
Layout layout;
layout.set_format(DENSE);
for (int i = major_to_minor.size() - 1; i >= 0; i--) {
@@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return false;
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
+/* static */ absl::Span<const int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
@@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return protobuf_util::ProtobufEquals(lhs, rhs);
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
+/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().minor_to_major());
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
+/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Layout& layout) {
CHECK(layout.format() == DENSE);
return AsInt64Slice(layout.minor_to_major());
@@ -472,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
- const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
+ const Layout& layout, absl::Span<const int64> dims) {
CHECK(IsDense(layout));
std::vector<int64> positions_in_layout;
for (int64 dim : dims) {
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 739bbe7367..b78883c2d8 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -34,11 +34,11 @@ class LayoutUtil {
public:
// Creates a layout with the given minor-to-major dimension order. (This is a
// convenience function for protobuf construction.)
- static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ static Layout MakeLayout(absl::Span<const int64> minor_to_major);
// Similar to MakeLayout, but take indices in reverse order.
static Layout MakeLayoutFromMajorToMinor(
- tensorflow::gtl::ArraySlice<int64> major_to_minor);
+ absl::Span<const int64> major_to_minor);
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
@@ -104,8 +104,7 @@ class LayoutUtil {
// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
- static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
- const Shape& shape);
+ static absl::Span<const int64> PaddedDimensions(const Shape& shape);
// Returns the given index of the padded_dimensions array for the given Shape.
// Requires that the shape is an array and has a dense layout.
@@ -138,8 +137,8 @@ class LayoutUtil {
// Returns the minor_to_major array for the given Shape. Requires that the
// shape is an array and has a dense layout.
- static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Shape& shape);
- static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Layout& layout);
+ static absl::Span<const int64> MinorToMajor(const Shape& shape);
+ static absl::Span<const int64> MinorToMajor(const Layout& layout);
// Major(0) is the most major logical dimension number, Major(1) is the
// second-most-major logical dimension number and so on.
@@ -196,7 +195,7 @@ class LayoutUtil {
// Returns whether the given dimensions are consecutive in the given layout,
// not necessarily in the order given.
static bool AreDimensionsConsecutive(const Layout& layout,
- tensorflow::gtl::ArraySlice<int64> dims);
+ absl::Span<const int64> dims);
// Compute a hash for `layout`.
static size_t Hash(const Layout& layout);
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
index e4c825450d..f25dae6ff4 100644
--- a/tensorflow/compiler/xla/layout_util_test.cc
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -27,15 +27,15 @@ namespace {
class LayoutUtilTest : public ::testing::Test {
protected:
Shape MakeShapeWithLayout(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
return shape;
}
Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
+ absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 93e808469a..3f7635bd40 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+ absl::Span<const int64> dimensions)
: dimensions(dimensions),
base(dimensions.size(), 0),
step(dimensions.size(), 1) {
@@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices(
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
- const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
+ const LiteralBase& src_literal, absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
auto linear_index = [](const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
};
@@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal(
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);
- auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto copy_proc = [&](absl::Span<const int64> indexes) {
// Map from multi-dimensional index, to source index.
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
src_indexes.begin(), std::plus<int64>());
@@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal(
return Status::OK();
}
-Status MutableLiteralBase::CopyElementFrom(
- const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index) {
+Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
+ absl::Span<const int64> src_index,
+ absl::Span<const int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
src_literal.shape(), src_index);
@@ -355,9 +353,9 @@ namespace {
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
-void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- tensorflow::gtl::ArraySlice<NativeT> src,
- const Shape& dest_shape, const Shape& src_shape) {
+void CopyElementsBetween(absl::Span<NativeT> dest,
+ absl::Span<const NativeT> src, const Shape& dest_shape,
+ const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
@@ -366,7 +364,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
do {
dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
- } while (IndexUtil::BumpIndices(dest_shape, &index));
+ } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
} // namespace
@@ -487,11 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK();
}
-Status MutableLiteralBase::CopySliceFrom(
- const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
+Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
<< ShapeUtil::HumanString(src_literal.shape());
@@ -591,8 +588,7 @@ std::unique_ptr<Literal> LiteralBase::Relayout(
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
}
@@ -615,7 +611,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
ShapeUtil::ForEachIndex(
- result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ result_shape, [&](absl::Span<const int64> output_index) {
for (int64 i = 0; i < dimensions.size(); ++i) {
scratch_source_index[i] = output_index[dimensions[i]];
}
@@ -632,7 +628,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
}
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
@@ -661,7 +657,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
}
std::unique_ptr<Literal> LiteralBase::Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const {
+ absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -700,12 +696,11 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const {
+ const Shape& result_shape, absl::Span<const int64> start_indices) const {
auto result_literal = absl::make_unique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
result_literal->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
+ [&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
@@ -716,8 +711,8 @@ std::unique_ptr<Literal> LiteralBase::SliceInternal(
}
std::unique_ptr<Literal> LiteralBase::Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -761,7 +756,7 @@ std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
return result;
}
-string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
@@ -858,7 +853,7 @@ string LiteralBase::GetSparseElementAsString(
}
StatusOr<int64> LiteralBase::GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@@ -900,8 +895,8 @@ size_t LiteralBase::Hash() const {
return hash_value;
}
-Status MutableLiteralBase::SetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
+Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
+ int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@@ -929,7 +924,7 @@ Status MutableLiteralBase::SetIntegralAsS64(
return Status::OK();
}
-tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
+absl::Span<const int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
@@ -998,7 +993,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() {
auto values = data<NativeT>();
CHECK_LE(num_elements, values.size());
sparse_indices()->SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
+ absl::Span<NativeT>(values.data(), num_elements));
}
namespace {
@@ -1064,8 +1059,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
CHECK(LayoutUtil::IsDenseArray(subshape));
- auto element_to_string =
- [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ auto element_to_string = [&](absl::Span<const int64> indices) -> string {
PrimitiveType element_type = subshape.element_type();
if (element_type == PRED) {
// We display predicates in a densely packed form.
@@ -1160,7 +1154,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(shape_to_string(subshape));
pieces->push_back(" {");
literal.EachCellAsString(
- [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ [&](absl::Span<const int64> indices, const string& value) {
pieces->push_back(" ");
pieces->push_back(value);
});
@@ -1183,7 +1177,7 @@ string LiteralBase::ToString(bool print_layout) const {
}
void LiteralBase::EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const std::function<void(absl::Span<const int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
@@ -1192,7 +1186,7 @@ void LiteralBase::EachCellAsString(
shape(), /*linear_index=*/0);
do {
per_cell(indices, GetAsString(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
namespace {
@@ -1250,10 +1244,8 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
- tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.data<NativeSrcT>();
- tensorflow::gtl::MutableArraySlice<complex64> dest_data =
- result_literal->data<complex64>();
+ absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
+ absl::Span<complex64> dest_data = result_literal->data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1392,12 +1384,12 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
elements.push_back(std::move(*new_element));
}
auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(&elements);
+ *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
return std::move(converted);
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements) {
+ absl::Span<Literal> elements) {
std::vector<Shape> element_shapes;
for (const Literal& element : elements) {
element_shapes.push_back(element.shape());
@@ -1488,7 +1480,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const {
namespace {
template <typename NativeT>
-static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+static bool AllElementsEqualValue(absl::Span<const NativeT> data,
NativeT value) {
for (int64 i = 0; i < data.size(); ++i) {
if (data[i] != value) {
@@ -1687,7 +1679,62 @@ bool LiteralBase::IsAllFirst() const {
});
}
-bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsR1Iota() const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return false;
+ }
+
+ if (ShapeUtil::Rank(shape()) != 1) {
+ return false;
+ }
+
+ auto is_iota_at_idx = [&](const int64 idx) {
+ switch (shape().element_type()) {
+ case U8:
+ return Get<uint8>({idx}) == idx;
+ case U16:
+ return Get<uint16>({idx}) == idx;
+ case U32:
+ return Get<uint32>({idx}) == idx;
+ case U64:
+ return Get<uint64>({idx}) == idx;
+ case S8:
+ return Get<int8>({idx}) == idx;
+ case S16:
+ return Get<int16>({idx}) == idx;
+ case S32:
+ return Get<int32>({idx}) == idx;
+ case S64:
+ return Get<int64>({idx}) == idx;
+ case F32:
+ return Get<float>({idx}) == idx;
+ case F64:
+ return Get<double>({idx}) == idx;
+ case F16:
+ return Get<half>({idx}) == static_cast<half>(idx);
+ case BF16:
+ return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
+ case C64:
+ return Get<complex64>({idx}) == complex64(idx, 0.0f);
+ case PRED:
+ return Get<bool>({idx}) == idx;
+ // token, opaque, tuple, etc. are all not iota.
+ default:
+ return false;
+ }
+ };
+
+ const int64 elements = ShapeUtil::ElementsIn(shape());
+ for (int64 idx = 0; idx < elements; ++idx) {
+ if (!is_iota_at_idx(idx)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
@@ -1723,7 +1770,7 @@ namespace {
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
- const tensorflow::gtl::ArraySlice<NativeT> src) {
+ const absl::Span<const NativeT> src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
@@ -1801,7 +1848,7 @@ void* LiteralBase::Piece::untyped_data() {
namespace {
template <typename RepeatedFieldT, typename NativeT>
-Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+Status CopyFromRepeatedField(absl::Span<NativeT> dest,
const RepeatedFieldT& src) {
if (dest.size() != src.size()) {
return InvalidArgument(
@@ -2071,8 +2118,8 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
root_piece_.set_subshape(shape_.get());
}
-BorrowingLiteral::BorrowingLiteral(
- tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
+BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
+ const Shape& shape)
: LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(ShapeUtil::IsTuple(*shape_));
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index aad435ed5b..b928cb6374 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -70,13 +70,12 @@ class LiteralBase {
// Serialize to proto.
LiteralProto ToProto() const;
- // Returns an ArraySlice of the array for this literal for the given NativeT
+ // Returns a Span of the array for this literal for the given NativeT
// (e.g., float). CHECKs if the subshape of the literal at the given
// ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
// to native type.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
+ absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
// Returns a const pointer to the sparse index array. Returns nullptr if the
// literal is not a sparse array.
@@ -100,12 +99,12 @@ class LiteralBase {
// Gets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT Get(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const;
// Overloads of Get for array literals. CHECKs if the literal is not
// array-shaped and dense.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ NativeT Get(absl::Span<const int64> multi_index) const;
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
@@ -114,7 +113,7 @@ class LiteralBase {
// As Get(), but determines the correct type and converts the value
// into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ string GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index = {}) const;
// As GetSparseElement(), but determines the correct type and converts the
// value into text.
@@ -122,14 +121,13 @@ class LiteralBase {
const ShapeIndex& shape_index = {}) const;
// As Get(), but determines the correct type and converts the value into
// int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;
// Returns the multi-index of the element in a sparse literal at the given
// sparse element number. The sparse element number is the position with in
// the sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ absl::Span<const int64> GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
// Returns the value of the element in a sparse literal at the given sparse
@@ -150,12 +148,12 @@ class LiteralBase {
//
// This literal must have a dense layout.
void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const std::function<void(absl::Span<const int64> indices,
const string& value)>& per_cell) const;
template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
+ void EachCell(
+ std::function<void(absl::Span<const int64> indices, NativeT value)>
+ per_cell) const;
// Returns whether every element in this literal is equal to value.
//
@@ -195,9 +193,12 @@ class LiteralBase {
// Literal consists entirely of the first element of the literal.
bool IsAllFirst() const;
+ // Literal consists entirely of an iota.
+ bool IsR1Iota() const;
+
// Returns whether this literal is zero at the specified index. This literal
// must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+ bool IsZero(absl::Span<const int64> indices) const;
// Returns the count of the elements in the array at the given shape index in
// this literal.
@@ -270,13 +271,12 @@ class LiteralBase {
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
+ const Shape& result_shape, absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,8 +285,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
+ std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -294,9 +293,8 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
+ std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
@@ -325,9 +323,9 @@ class LiteralBase {
// Returns the buffer holding the array data for this piece as an array
// slice. This piece must be array-shaped.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
+ absl::Span<const NativeT> data() const;
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
+ absl::Span<NativeT> data();
// Returns the buffer holding the array data for this piece as a void*. This
// piece must be array-shaped.
@@ -338,9 +336,9 @@ class LiteralBase {
// is CHECKed against the dimension sizes of the array. This piece must be
// array-shaped.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ NativeT Get(absl::Span<const int64> index) const;
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+ void Set(absl::Span<const int64> index, NativeT value);
// Gets/sets the buffer holding the array data.
char* buffer() const { return buffer_; }
@@ -542,8 +540,7 @@ class LiteralBase {
private:
template <typename NativeT>
std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const;
+ const Shape& result_shape, absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@@ -551,13 +548,12 @@ class MutableLiteralBase : public LiteralBase {
public:
virtual ~MutableLiteralBase() = 0;
- // Returns a MutableArraySlice view of the array for this literal for the
+ // Returns a Span view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
// given ShapeIndex is not array. See primitive_util.h for the mapping from
// XLA type to native type.
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {});
+ absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
// Unhide const method from parent class.
using LiteralBase::data;
@@ -584,8 +580,7 @@ class MutableLiteralBase : public LiteralBase {
// are populated.
template <typename NativeT>
void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
+ absl::Span<const NativeT> values, bool sort = true);
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
@@ -606,39 +601,38 @@ class MutableLiteralBase : public LiteralBase {
// corresponding base indices being 0.
// This literal and 'src_literal' must be arrays.
Status CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size);
// Copies one element from src_literal[src_index] to (*this)[dest_index].
Status CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
+ absl::Span<const int64> src_index,
+ absl::Span<const int64> dest_index);
// Sets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
+ void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
+ NativeT value);
// Overloads of Set for array literals. CHECKs if the literal is not
// array-shaped and dense.
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+ void Set(absl::Span<const int64> multi_index, NativeT value);
// Appends the given element to the literal. If the elements are not appended
// in sorted order, then SortSparseElements should be called before calling
// other methods. This literal must have a sparse layout.
template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
+ void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
+ const ShapeIndex& shape_index = {});
// Sorts the elements in a sparse array.
void SortSparseElements(const ShapeIndex& shape_index = {});
// As Set(), but truncates `value` to the literal element type before storing.
// This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
+ Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
// Populate this literal with the given values. Examples:
//
@@ -653,7 +647,7 @@ class MutableLiteralBase : public LiteralBase {
// example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
// array of S32.
template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(absl::Span<const NativeT> values);
void PopulateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
@@ -670,7 +664,7 @@ class MutableLiteralBase : public LiteralBase {
// in this literal object.
//
// generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ // NativeT(absl::Span<int64> indexes) or compatible.
//
// This literal must have a dense layout.
template <typename NativeT, typename FnType>
@@ -690,8 +684,7 @@ class MutableLiteralBase : public LiteralBase {
// moved into the tuple elements of a new tuple-shaped Literal which is
// returned. Upon return, each of the Literals in 'elements' is set to a nil
// shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
+ static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
@@ -709,20 +702,20 @@ class MutableLiteralBase : public LiteralBase {
// arguments one by one.
template <typename NativeT>
Status CopySliceFromInternal(const LiteralBase& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size);
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
StrideConfig(const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// The dimensions of the stride operation. Essentially every dimension
// will be iterated from base[i] to base[i]+dimensions[i], in step[i]
// steps.
- tensorflow::gtl::ArraySlice<int64> dimensions;
+ absl::Span<const int64> dimensions;
DimensionVector base;
DimensionVector step;
int64 minor_dimension = 0;
@@ -851,7 +844,7 @@ class BorrowingLiteral : public LiteralBase {
// This constructor is only used for array shapes.
BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
// Similar as above, except to be used for constructing non-nested tuples.
- BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
+ BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
const Shape& shape);
// TODO(b/79707221): adding constructors for nested tuples as well.
@@ -871,7 +864,7 @@ class BorrowingLiteral : public LiteralBase {
};
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
+absl::Span<const NativeT> LiteralBase::Piece::data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -879,12 +872,12 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::ArraySlice<NativeT>(
- reinterpret_cast<const NativeT*>(buffer()), element_count());
+ return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
+ element_count());
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
+absl::Span<NativeT> LiteralBase::Piece::data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -892,20 +885,19 @@ tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::MutableArraySlice<NativeT>(
- reinterpret_cast<NativeT*>(buffer()), element_count());
+ return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
+ element_count());
}
template <typename NativeT>
-NativeT LiteralBase::Piece::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)];
}
template <typename NativeT>
-void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
NativeT value) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
@@ -913,39 +905,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
}
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
+absl::Span<const NativeT> LiteralBase::data(
const ShapeIndex& shape_index) const {
return piece(shape_index).data<NativeT>();
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
- const ShapeIndex& shape_index) {
+absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}
template <typename NativeT>
-inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
return piece(shape_index).Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline NativeT LiteralBase::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
return root_piece().Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline void MutableLiteralBase::Set(
- tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
+inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
+ const ShapeIndex& shape_index,
+ NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}
template <typename NativeT>
-inline void MutableLiteralBase::Set(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
+inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
+ NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}
@@ -964,7 +954,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
template <typename NativeT>
void MutableLiteralBase::AppendSparseElement(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
+ absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
@@ -980,8 +970,7 @@ void MutableLiteralBase::AppendSparseElement(
template <typename NativeT>
void LiteralBase::EachCell(
- std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
+ std::function<void(absl::Span<const int64> indices, NativeT value)>
per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
@@ -989,12 +978,11 @@ void LiteralBase::EachCell(
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
do {
per_cell(indices, Get<NativeT>(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
template <typename NativeT>
-inline void MutableLiteralBase::PopulateR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
+inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@@ -1039,8 +1027,9 @@ void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
for (int dim = 0; dim < values.num_dimensions(); ++dim) {
CHECK_EQ(values.dim(dim), shape().dimensions(dim));
}
- values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value) { this->Set(indices, value); });
+ values.Each([this](absl::Span<const int64> indices, NativeT value) {
+ this->Set(indices, value);
+ });
}
template <typename NativeT>
@@ -1059,9 +1048,9 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
}
template <typename NativeT>
-void MutableLiteralBase::PopulateSparse(
- SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
+void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
CHECK_EQ(indices.rank(), rank);
@@ -1071,7 +1060,7 @@ void MutableLiteralBase::PopulateSparse(
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
- // Piece::data() returns an ArraySlice of size equal to the number of indices
+ // Piece::data() returns a Span of size equal to the number of indices
// in the SparseIndexArray. So there is no need to adjust the size of the data
// here. It is enough to just copy the incoming values into the data buffer.
std::copy(values.begin(), values.end(), root_data.begin());
@@ -1091,14 +1080,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
- tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
+ absl::Span<NativeT> literal_data = data<NativeT>();
if (rank > 0) {
StrideConfig stride_config(this_shape, this_shape,
AsInt64Slice(this_shape.dimensions()));
int64 minor_dimension_size =
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
- auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto init_function = [&](absl::Span<const int64> indexes) {
DimensionVector minor_scan_indexes(rank, 0);
const int64 index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
@@ -1116,7 +1105,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
ShapeUtil::ForEachIndex(
this_shape, stride_config.base, stride_config.dimensions,
stride_config.step,
- [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+ [&init_function](absl::Span<const int64> indexes) {
init_function(indexes);
return true;
});
@@ -1162,7 +1151,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
}
DimensionVector output_indices(bounds.size(), 0);
- tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
+ absl::Span<const int64> input_indices = output_indices;
input_indices.remove_prefix(1);
bool done = false;
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 43388ac9d1..3d8725ed70 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -38,8 +38,8 @@ namespace {
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
-Status CompareFloatsBitwiseEqual(
- FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice<int64> multi_index) {
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
+ absl::Span<const int64> multi_index) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
@@ -60,7 +60,7 @@ Status CompareFloatsBitwiseEqual(
// default gunit implementation).
template <typename NativeT>
Status CompareEqual(NativeT lhs, NativeT rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
if (lhs == rhs) {
return Status::OK();
}
@@ -74,28 +74,27 @@ Status CompareEqual(NativeT lhs, NativeT rhs,
// comparison is requested.
template <>
Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<Eigen::half>(
- Eigen::half lhs, Eigen::half rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<float>(float lhs, float rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<double>(double lhs, double rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index);
if (!res.ok()) {
return res;
@@ -108,8 +107,7 @@ Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
// elements are equal.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
- tensorflow::gtl::MutableArraySlice<int64> multi_index,
- int64 dimension) {
+ absl::Span<int64> multi_index, int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
@@ -164,6 +162,17 @@ bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
static_cast<float>(actual), relaxed_nans);
}
+// Returns whether the given value is infinity.
+template <typename NativeT>
+bool IsInf(NativeT val) {
+ return std::isinf(val);
+}
+
+template <>
+bool IsInf<half>(half val) {
+ return std::isinf(static_cast<float>(val));
+}
+
// Converts the given floating-point value to a string.
template <typename NativeT>
string FpValueToString(NativeT value) {
@@ -294,8 +303,7 @@ class NearComparator {
}
// Insert the given error into the given error bucket vector.
- void UpdateErrorBucket(
- float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
+ void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
for (int i = 0; i < error_buckets.size(); ++i) {
if (error >= kErrorBucketBounds[i]) {
@@ -306,12 +314,13 @@ class NearComparator {
// Compares the two given elements from the expected and actual literals at
// the given literal_index and keeps track of various mismatch statistics.
- void CompareValues(NativeT expected, NativeT actual, int64 linear_index) {
+ template <typename T>
+ void CompareValues(T expected, T actual, int64 linear_index) {
const bool is_nan_mismatch =
NanMismatch(expected, actual, error_.relaxed_nans);
float abs_error;
float rel_error;
- if (actual == expected) {
+ if (CompareEqual<T>(expected, actual, {linear_index}).ok()) {
abs_error = 0;
rel_error = 0;
} else if (is_nan_mismatch) {
@@ -322,6 +331,12 @@ class NearComparator {
// weak ordering requirement of std containers.
abs_error = std::numeric_limits<float>::infinity();
rel_error = std::numeric_limits<float>::infinity();
+ } else if (IsInf(expected) || IsInf(actual)) {
+ // If either the expected or actual value is infinity but not both,
+ // then both absolute and relative error are regarded as inifity.
+ CHECK(!CompareEqual(expected, actual, {linear_index}).ok());
+ abs_error = std::numeric_limits<float>::infinity();
+ rel_error = std::numeric_limits<float>::infinity();
} else {
abs_error = FpAbsoluteValue(actual - expected);
rel_error = abs_error / FpAbsoluteValue(expected);
@@ -335,11 +350,11 @@ class NearComparator {
// bound is exceeded and vice versa.
if (is_abs_mismatch) {
num_abs_mismatches_++;
- UpdateErrorBucket(rel_error, &rel_error_buckets_);
+ UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
}
if (is_rel_mismatch) {
num_rel_mismatches_++;
- UpdateErrorBucket(abs_error, &abs_error_buckets_);
+ UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
}
UpdateAbsValueBucket(actual, is_mismatch);
@@ -364,15 +379,36 @@ class NearComparator {
mismatches_.data<bool>()[linear_index] = true;
}
+ // For complex64 types, we compare real and imaginary parts individually.
+ void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
+ bool mismatch = false;
+ CompareValues<float>(expected.real(), actual.real(), linear_index);
+ if (mismatches_.data<bool>()[linear_index] == true) {
+ mismatch = true;
+ // Delay the mismatch count increase for real part, instead increase
+ // mismatch by 1 for the entire complex number.
+ num_mismatches_--;
+ }
+ CompareValues<float>(expected.imag(), actual.imag(), linear_index);
+ if (mismatches_.data<bool>()[linear_index] == true) {
+ mismatch = true;
+ // Delay the mismatch count increase for imag part, instead increase
+ // mismatch by 1 for the entire complex number.
+ num_mismatches_--;
+ }
+ if (mismatch == true) {
+ num_mismatches_++;
+ }
+ mismatches_.data<bool>()[linear_index] = mismatch;
+ }
+
// Compares the two literals elementwise.
void CompareLiterals() {
// Fast path optimization for the case were layouts match.
if (LayoutUtil::Equal(actual_.shape().layout(),
expected_.shape().layout())) {
- tensorflow::gtl::ArraySlice<const NativeT> expected_data =
- expected_.data<NativeT>();
- tensorflow::gtl::ArraySlice<const NativeT> actual_data =
- actual_.data<NativeT>();
+ absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
+ absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
const int64 len = expected_data.size();
for (int64 i = 0; i < len; ++i) {
CompareValues(expected_data[i], actual_data[i], i);
@@ -447,7 +483,7 @@ class NearComparator {
}
auto print_accum_buckets = [&](const string& header, int64 total,
- tensorflow::gtl::ArraySlice<int64> buckets) {
+ absl::Span<const int64> buckets) {
StrAppend(&out, header, ":\n");
StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
total - buckets[0],
@@ -538,40 +574,41 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ auto index = absl::MakeSpan(multi_index);
Status result;
switch (expected.shape().element_type()) {
case PRED:
- result = Equal<bool>(expected, actual, &multi_index, 0);
+ result = Equal<bool>(expected, actual, index, 0);
break;
case U8:
- result = Equal<uint8>(expected, actual, &multi_index, 0);
+ result = Equal<uint8>(expected, actual, index, 0);
break;
case S32:
- result = Equal<int32>(expected, actual, &multi_index, 0);
+ result = Equal<int32>(expected, actual, index, 0);
break;
case S64:
- result = Equal<int64>(expected, actual, &multi_index, 0);
+ result = Equal<int64>(expected, actual, index, 0);
break;
case U32:
- result = Equal<uint32>(expected, actual, &multi_index, 0);
+ result = Equal<uint32>(expected, actual, index, 0);
break;
case U64:
- result = Equal<uint64>(expected, actual, &multi_index, 0);
+ result = Equal<uint64>(expected, actual, index, 0);
break;
case BF16:
- result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ result = Equal<bfloat16>(expected, actual, index, 0);
break;
case F16:
- result = Equal<half>(expected, actual, &multi_index, 0);
+ result = Equal<half>(expected, actual, index, 0);
break;
case F32:
- result = Equal<float>(expected, actual, &multi_index, 0);
+ result = Equal<float>(expected, actual, index, 0);
break;
case F64:
- result = Equal<double>(expected, actual, &multi_index, 0);
+ result = Equal<double>(expected, actual, index, 0);
break;
case C64:
- result = Equal<complex64>(expected, actual, &multi_index, 0);
+ result = Equal<complex64>(expected, actual, index, 0);
break;
case TUPLE: {
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index e08a9d6e41..1a64594db8 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -36,7 +36,6 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
@@ -122,10 +121,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
EXPECT_EQ("0.5", bf16_lit->ToString());
- // 3.14 will be truncated to 3.125 in bfloat16 format.
+ // 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- EXPECT_EQ("3.125", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
@@ -222,9 +221,9 @@ TEST_F(LiteralUtilTest, CreateSparse) {
std::vector<int64> expected_values = {8, 9, 7, 10};
EXPECT_EQ(literal->sparse_indices()->data(),
- ArraySlice<int64>(expected_indices.data(),
- expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), ArraySlice<int64>(expected_values));
+ absl::Span<const int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -296,7 +295,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
literal->EachCellAsString(
- [&seen](ArraySlice<int64> indices, const string& value) {
+ [&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -649,7 +648,7 @@ TEST_F(LiteralUtilTest, TransposeR4) {
// clang-format on
auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
+ reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
EXPECT_EQ(value, original->Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
@@ -889,7 +888,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 zero_base[] = {0, 0, 0, 0};
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
- auto init_proc = [&](ArraySlice<int64> indexes) {
+ auto init_proc = [&](absl::Span<const int64> indexes) {
source->Set(indexes, ++seqnr);
return true;
};
@@ -905,7 +904,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
bool matched = true;
- auto check_proc = [&](ArraySlice<int64> indexes) {
+ auto check_proc = [&](absl::Span<const int64> indexes) {
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
source_indexes.begin(), std::plus<int64>());
@@ -1093,7 +1092,7 @@ TEST_F(LiteralUtilTest, Populate) {
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
auto literal = absl::make_unique<Literal>(shape);
- auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+ auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@@ -1105,7 +1104,7 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
- auto check_function = [&](ArraySlice<int64> indexes) {
+ auto check_function = [&](absl::Span<const int64> indexes) {
auto value = literal->Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
@@ -1135,7 +1134,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
auto literal = absl::make_unique<Literal>(shape);
- auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+ auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@@ -1147,7 +1146,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
- auto check_function = [&](ArraySlice<int64> indexes) {
+ auto check_function = [&](absl::Span<const int64> indexes) {
auto value = literal->Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
@@ -1561,7 +1560,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) {
));
- Literal literal = Literal::MoveIntoTuple(&elements);
+ Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 931d2c631b..613449cf10 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -84,8 +84,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
@@ -301,9 +300,8 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const LiteralSlice& literal) {
+ absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
@@ -430,7 +428,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
- tensorflow::gtl::ArraySlice<const Literal*> elements) {
+ absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
@@ -444,7 +442,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
- tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+ absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
@@ -474,7 +472,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
/* static */ string LiteralUtil::MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return StrCat("{", absl::StrJoin(multi_index, ","), "}");
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 3d28c070f2..2d6084a67a 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -71,8 +71,7 @@ class LiteralUtil {
template <typename NativeT>
static std::unique_ptr<Literal> CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(
- tensorflow::gtl::ArraySlice<NativeT> values);
+ static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
static std::unique_ptr<Literal> CreateR1(
const tensorflow::core::Bitmap& values);
template <typename NativeT>
@@ -141,8 +140,8 @@ class LiteralUtil {
//
template <typename NativeT>
static std::unique_ptr<Literal> CreateSparse(
- tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
+ absl::Span<const int64> dimensions, SparseIndexArray indices,
+ absl::Span<const NativeT> values, bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -157,7 +156,7 @@ class LiteralUtil {
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
- tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
+ absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
@@ -215,10 +214,10 @@ class LiteralUtil {
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
static std::unique_ptr<Literal> MakeTuple(
- tensorflow::gtl::ArraySlice<const Literal*> elements);
+ absl::Span<const Literal* const> elements);
static std::unique_ptr<Literal> MakeTupleFromSlices(
- tensorflow::gtl::ArraySlice<LiteralSlice> elements);
+ absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
@@ -259,8 +258,7 @@ class LiteralUtil {
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ PrimitiveType primitive_type, absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
@@ -279,9 +277,8 @@ class LiteralUtil {
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
static std::unique_ptr<Literal> ReshapeSlice(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const LiteralSlice& literal);
+ absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -291,7 +288,7 @@ class LiteralUtil {
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+ const std::function<T(absl::Span<const int64>)>& generator);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -319,8 +316,7 @@ class LiteralUtil {
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.
- static string MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index);
+ static string MultiIndexAsString(absl::Span<const int64> multi_index);
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
@@ -335,7 +331,7 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
+ absl::Span<const NativeT> values) {
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
@@ -427,8 +423,8 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
- tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
+ absl::Span<const int64> dimensions, SparseIndexArray indices,
+ absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
@@ -570,8 +566,8 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(
- tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
+LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
+ NativeT value) {
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
@@ -583,14 +579,12 @@ template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(
const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+ const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
auto literal = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- return generator(indexes);
- }));
+ [&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
@@ -601,9 +595,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
- shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
- return generator(*engine);
- });
+ shape,
+ [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
}
template <PrimitiveType type, typename T>
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 6e42775f6f..f9473d372b 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -61,10 +61,10 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
- tensorflow::gtl::ArraySlice<float> field = result->data<float>();
- char* data = tensorflow::bit_cast<char*>(field.data());
+ absl::Span<const float> field = result->data<float>();
+ char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- tensorflow::StringPiece sp; // non-absl OK
+ absl::string_view sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- tensorflow::StringPiece sp; // non-absl OK
+ absl::string_view sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index fe91dc0618..f0d84646b9 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -41,6 +41,7 @@ cc_library(
"//tensorflow/python:numpy_lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -62,6 +63,7 @@ cc_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index b5fd747cfa..cd6e20b693 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -259,7 +259,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
}
LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
- tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
+ absl::Span<LocalShapedBuffer* const> argument_handles) {
LocalClient* client = GetOrCreateLocalClient();
std::vector<const ShapedBuffer*> argument_buffers;
@@ -369,8 +369,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
}
LocalOp LocalComputationBuilder::Broadcast(
- const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const LocalOp& operand, absl::Span<const int64> broadcast_sizes) {
return xla::Broadcast(operand.op(), broadcast_sizes);
}
@@ -380,14 +379,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
return xla::Pad(operand.op(), padding_value.op(), padding_config);
}
-LocalOp LocalComputationBuilder::Reshape(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return xla::Reshape(operand.op(), dimensions, new_sizes);
}
-LocalOp LocalComputationBuilder::Collapse(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand,
+ absl::Span<const int64> dimensions) {
return xla::Collapse(operand.op(), dimensions);
}
@@ -395,10 +394,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
return xla::CrossReplicaSum(operand.op());
}
-LocalOp LocalComputationBuilder::Slice(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+LocalOp LocalComputationBuilder::Slice(const LocalOp& operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}
@@ -411,7 +410,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
LocalOp LocalComputationBuilder::DynamicSlice(
const LocalOp& operand, const LocalOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}
@@ -421,8 +420,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice(
return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}
-LocalOp LocalComputationBuilder::ConcatInDim(
- tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) {
+LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
+ int64 dimension) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -433,18 +432,16 @@ LocalOp LocalComputationBuilder::ConcatInDim(
LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const LocalOp& source, const LocalOp& init_value,
- const LocalComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
+ const LocalOp& init_value, const LocalComputation& scatter) {
return xla::SelectAndScatterWithGeneralPadding(
operand.op(), select.computation(), window_dimensions, window_strides,
padding, source.op(), init_value.op(), scatter.computation());
}
-LocalOp LocalComputationBuilder::Tuple(
- tensorflow::gtl::ArraySlice<LocalOp> elements) {
+LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(elements.size());
for (const auto& op : elements) {
@@ -471,10 +468,9 @@ LocalOp LocalComputationBuilder::DotGeneral(
LocalOp LocalComputationBuilder::ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers);
@@ -490,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType(
return xla::BitcastConvertType(operand.op(), new_element_type);
}
-LocalOp LocalComputationBuilder::Call(
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<LocalOp> operands) {
+LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
+ absl::Span<const LocalOp> operands) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -502,19 +497,18 @@ LocalOp LocalComputationBuilder::Call(
}
LocalOp LocalComputationBuilder::Transpose(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
+ const LocalOp& operand, absl::Span<const int64> permutation) {
return xla::Transpose(operand.op(), permutation);
}
-LocalOp LocalComputationBuilder::Rev(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Rev(const LocalOp& operand,
+ absl::Span<const int64> dimensions) {
return xla::Rev(operand.op(), dimensions);
}
-LocalOp LocalComputationBuilder::Map(
- tensorflow::gtl::ArraySlice<LocalOp> operands,
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
+ const LocalComputation& local_computation,
+ absl::Span<const int64> dimensions) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -528,7 +522,7 @@ LocalOp LocalComputationBuilder::Map(
LocalOp LocalComputationBuilder::Reduce(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ absl::Span<const int64> dimensions_to_reduce) {
return xla::Reduce(operand.op(), init_value.op(),
local_computation.computation(), dimensions_to_reduce);
}
@@ -536,9 +530,9 @@ LocalOp LocalComputationBuilder::Reduce(
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
window_dimensions, window_strides, padding);
@@ -599,10 +593,10 @@ StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
#define _FORWARD_UNOP(method_name) \
_FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))
-#define _FORWARD_BINOP(method_name) \
- _FORWARD(method_name, LocalOp, \
- (const LocalOp& lhs, const LocalOp& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, \
+ absl::Span<const int64> broadcast_dimensions), \
(lhs.op(), rhs.op(), broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index d9543b958d..78b3c598b9 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace swig {
@@ -122,7 +122,7 @@ class CompiledLocalComputation {
const std::vector<absl::optional<Shape> >& shapes_with_layout);
LocalShapedBuffer* ExecuteWithShapedBuffers(
- tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
+ absl::Span<LocalShapedBuffer* const> argument_handles);
private:
std::unique_ptr<LocalExecutable> executable_;
@@ -199,46 +199,41 @@ class LocalComputationBuilder {
LocalOp ConstantLiteral(const Literal& literal);
LocalOp Broadcast(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
const PaddingConfig& padding_config);
- LocalOp Reshape(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
- LocalOp Collapse(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions);
LocalOp CrossReplicaSum(const LocalOp& operand);
- LocalOp Slice(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);
LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
const LocalOp& start_indices);
- LocalOp ConcatInDim(tensorflow::gtl::ArraySlice<LocalOp> operands,
- int64 dimension);
+ LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);
LocalOp SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
- const LocalOp& source, const LocalOp& init_value,
- const LocalComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
+ const LocalOp& init_value, const LocalComputation& scatter);
- LocalOp Tuple(tensorflow::gtl::ArraySlice<LocalOp> elements);
+ LocalOp Tuple(absl::Span<const LocalOp> elements);
LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);
@@ -249,10 +244,10 @@ class LocalComputationBuilder {
LocalOp ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
LocalOp ConvertElementType(const LocalOp& operand,
@@ -262,28 +257,27 @@ class LocalComputationBuilder {
PrimitiveType new_element_type);
LocalOp Call(const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<LocalOp> operands);
+ absl::Span<const LocalOp> operands);
LocalOp Transpose(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ absl::Span<const int64> permutation);
- LocalOp Rev(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);
- LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
+ LocalOp Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
LocalOp ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
const Shape& shape);
@@ -316,7 +310,7 @@ class LocalComputationBuilder {
#define _FORWARD_BINOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
+ absl::Span<const int64> broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
_FORWARD(method_name, LocalOp, \
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index f6169ebf19..76c09512d8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -22,15 +22,15 @@ limitations under the License.
//
// C++ Python
// -------------------------------------+---------------------------------------
-// ArraySlice<int64> <- sequence of int
-// ArraySlice<LocalOp> <- sequence of LocalOp
+// Span<int64> <- sequence of int
+// Span<LocalOp> <- sequence of LocalOp
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape -> pair holding (dtype, dimensions)
// <- object duck-typed as xla_client.Shape
// std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int
-// ArraySlice<pair<int64, in64>> <- sequence of int pairs
+// Span<pair<int64, in64>> <- sequence of int pairs
// PaddingConfig proto <- corresponding Python proto
// ConvolutionDimensionNumbers proto <- corresponding Python proto
// DotDimensionNumbers proto <- corresponding Python proto
@@ -114,7 +114,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "third_party/absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
@@ -267,9 +267,9 @@ tensorflow::ImportNumpy();
$result = Py_None;
}
-// ArraySlice<int64>
+// Span<int64>
-%typemap(in) tensorflow::gtl::ArraySlice<int64>
+%typemap(in) absl::Span<const int64>
(std::vector<int64> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -299,9 +299,9 @@ tensorflow::ImportNumpy();
$1 = temps;
}
-// ArraySlice<LocalOp>
+// Span<LocalOp>
-%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalOp>(
+%typemap(in) absl::Span<const xla::swig::LocalOp>(
std::vector<LocalOp> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -323,7 +323,7 @@ tensorflow::ImportNumpy();
// LocalShapedBuffer*
-%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
+%typemap(in) absl::Span<xla::swig::LocalShapedBuffer* const>
(std::vector<LocalShapedBuffer*> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -496,9 +496,9 @@ tensorflow::ImportNumpy();
$1 = static_cast<PrimitiveType>(value);
}
-// ArraySlice<pair<int64, in64>>
+// Span<pair<int64, in64>>
-%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
+%typemap(in) absl::Span<const std::pair<int64, int64> >
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index a67c93a4fb..8cae175185 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -25,9 +25,9 @@ limitations under the License.
#include <algorithm>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 3de7ee2bc8..a4854f593f 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -108,17 +108,15 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
// array by adding a fourth dummy dimension of size 1 without stride, padding
// and dilation.
Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
- a4dlhs.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
- });
+ a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
+ });
Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
- a4drhs.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
- });
+ a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
+ });
// Add a second dummy spatial dimensions.
ConvolutionDimensionNumbers dnums2d = dnums;
dnums2d.add_input_spatial_dimensions(3);
@@ -130,11 +128,10 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
auto convr3 = absl::make_unique<Array3D<float>>(
convr4->planes(), convr4->depth(), convr4->height());
- convr4->Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
- });
+ convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
+ });
return convr3;
}
@@ -189,11 +186,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const absl::Span<const float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -221,10 +218,11 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
+ float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
return ReduceWindow1DGeneric(
@@ -236,9 +234,9 @@ ReferenceUtil::ReduceWindow1DAdd(
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -276,8 +274,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
const Array2D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -287,8 +285,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
const Array3D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -334,8 +332,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -347,9 +345,9 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -402,8 +400,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
const Array4D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -424,10 +422,12 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>>
-ReferenceUtil::SelectAndScatter4DGePlus(
- const Array4D<float>& operand, const Array4D<float>& source, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
+ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
+ const Array4D<float>& source,
+ float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
operand.n3(), operand.n4());
@@ -591,7 +591,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
result_literal->shape().dimensions(2),
result_literal->shape().dimensions(3));
- result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+ result->Each([&](absl::Span<const int64> indices, float* value) {
*value = result_literal->Get<float>(indices);
});
@@ -633,8 +633,7 @@ ReferenceUtil::ReduceToRowArray2D(
}
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
- const Array4D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array4D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function) {
std::vector<float> result;
CHECK_EQ(dims.size(), 3);
@@ -707,8 +706,7 @@ ReferenceUtil::ReduceToRowArray2D(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
- const Array3D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array3D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function) {
CHECK_EQ(dims.size(), 1);
int64 rows = dims[0] == 0 ? array.n2() : array.n1();
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 88f853a359..9ce098029d 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -23,13 +23,13 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -144,8 +144,7 @@ class ReferenceUtil {
// Returns the result of reducing the 4D array to a vector, reducing away
// the dimensions specified in dims.
static std::vector<float> Reduce4DTo1D(
- const Array4D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array4D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function);
// Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
@@ -156,8 +155,7 @@ class ReferenceUtil {
// Returns the result of reducing the 3D array to a 2D array, reducing away
// the dimensions specified in dims.
static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
- const Array3D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array3D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function);
// Applies map_function to each element in the input (2D array) and returns
@@ -179,47 +177,47 @@ class ReferenceUtil {
// Windowed reductions with Add as the function to apply.
static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const float>& operand, float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
const Array2D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
const Array3D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
const Array4D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
// Windowed reductions with a generic reduce function.
static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const absl::Span<const float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
// With arbitrary padding.
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
// Batch normalize data.
static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -232,8 +230,8 @@ class ReferenceUtil {
// TODO(b/74533103) Switch tests to evaluator and remove this implementation.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, bool same_padding);
// Concatenates the lhs and rhs arrays along the concatenate_dimension.
// E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -334,8 +332,8 @@ class ReferenceUtil {
// Slices with index clamping
template <typename T>
- static std::vector<T> ClampSlice1D(
- const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
+ static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
+ int64 start, int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
@@ -633,7 +631,7 @@ class ReferenceUtil {
Array4D<NativeT> result(output_bounds[0], output_bounds[1],
output_bounds[2], output_bounds[3]);
result.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) {
+ [&](absl::Span<const int64> indices, NativeT* value) {
for (int i = 0; i < 4; ++i) {
bool in_low_padding = indices[i] < pad_low[i];
bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b68785949c..26b48cf419 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -69,6 +69,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -103,6 +104,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -179,6 +181,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -195,6 +198,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -247,6 +251,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -326,6 +331,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -378,6 +384,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/types:span",
],
)
@@ -435,6 +442,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -597,6 +605,7 @@ cc_library(
"//third_party/eigen3",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -641,6 +650,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@@ -676,6 +686,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -752,6 +763,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -787,9 +799,11 @@ cc_library(
":hlo_execution_profile",
":hlo_graph_dumper",
":hlo_proto",
+ ":maybe_owning_device_memory",
":shaped_buffer",
":stream_pool",
"//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -802,6 +816,8 @@ cc_library(
"//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -820,6 +836,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/types:span",
],
)
@@ -851,6 +868,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -907,6 +925,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1001,6 +1020,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1026,6 +1046,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -1144,6 +1165,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1265,6 +1287,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
@@ -1287,6 +1310,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1350,6 +1374,7 @@ cc_library(
hdrs = ["algebraic_simplifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_creation_utils",
":hlo_pass",
":hlo_query",
@@ -1367,6 +1392,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1376,6 +1402,7 @@ tf_cc_test(
deps = [
":algebraic_simplifier",
":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
@@ -1702,6 +1729,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1795,6 +1823,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1902,6 +1931,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1918,6 +1948,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1946,6 +1977,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1965,6 +1997,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1986,6 +2019,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2087,6 +2121,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2107,6 +2142,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2148,6 +2184,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2201,6 +2238,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2426,6 +2464,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2568,6 +2607,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2675,6 +2715,22 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "maybe_owning_device_memory",
+ srcs = [
+ "maybe_owning_device_memory.cc",
+ ],
+ hdrs = [
+ "maybe_owning_device_memory.h",
+ ],
+ deps = [
+ ":device_memory_allocator",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -3023,6 +3079,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -3043,7 +3100,7 @@ cc_library(
hdrs = ["tuple_util.h"],
deps = [
":hlo",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c236453fc7..7c078f07d7 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -26,13 +26,16 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -44,7 +47,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -125,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleImag(HloInstruction* imag) override;
+ Status HandleIota(HloInstruction* instruction) override;
+
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleDivide(HloInstruction* divide) override;
@@ -447,8 +451,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
Status AlgebraicSimplifierVisitor::HandleConcatenate(
HloInstruction* concatenate) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(
- concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
if (operands.size() == 1) {
// Unary concatenates are useless.
ReplaceInstructionIfSameShape(concatenate, operands[0]);
@@ -551,6 +554,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant,
HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
}
+
+ // If a literal is an increasing sequence from zero, replace it with an iota.
+ if (ShapeUtil::Rank(constant->shape()) == 1 &&
+ ShapeUtil::ElementsIn(constant->shape()) > 1 &&
+ constant->literal().IsR1Iota()) {
+ return ReplaceWithNewInstruction(
+ constant, HloInstruction::CreateIota(constant->shape(), 0));
+ }
return Status::OK();
}
@@ -578,7 +589,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
- return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return result->Populate<T>([&](absl::Span<const int64> indices) {
return T{1.0} / constant.literal().Get<T>(indices);
});
}
@@ -1238,9 +1249,8 @@ namespace {
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
- const HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
+absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+ const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
@@ -1258,11 +1268,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
- return std::make_pair(false, std::vector<int64>());
+ return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
- return std::make_pair(true, output_dim_indices);
+ return output_dim_indices;
}
// Returns true if the output of "instruction" is a permutation of the
@@ -1391,6 +1401,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+ // broadcast(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateIota(
+ broadcast->shape(),
+ dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
+ }
+
// Merge two consecutive broadcasts into a single one.
if (operand->opcode() == HloOpcode::kBroadcast) {
std::vector<int64> new_dimensions;
@@ -1445,6 +1464,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
+ // iota -> zero if the iota dimension never produces an element other than
+ // zero.
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
+ auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+ return ReplaceWithNewInstruction(
+ iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
+ }
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
@@ -1719,12 +1751,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
reshape, reshape->operand(0)->dimensions());
- if (opt_dims.first) {
+ if (opt_dims.has_value()) {
return ReplaceWithNewInstruction(
reshape,
HloInstruction::CreateBroadcast(
reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
- opt_dims.second));
+ *opt_dims));
+ }
+ }
+
+ // reshape(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ auto* iota = Cast<HloIotaInstruction>(operand);
+ auto opt_dims =
+ ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
+ if (opt_dims.has_value()) {
+ CHECK_EQ(opt_dims->size(), 1);
+ return ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
}
}
@@ -1821,7 +1866,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
ShapeUtil::IsZeroElementArray(reduce->shape())) {
@@ -2183,7 +2228,141 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
.CloneToUnique())),
{}));
}
+
const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
+ return false;
+ }
+
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
+ }
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = rhs->padding_config();
+
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
+
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
+ }
+
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
+
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
+ }
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
if (!enable_conv_simplification_) {
return Status::OK();
}
@@ -2200,8 +2379,6 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
return Status::OK();
}
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
const Shape& input_shape = lhs->shape();
const Shape& filter_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index bb63ea26d4..43a891e4fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -23,8 +23,10 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -52,12 +54,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {
- public:
- AlgebraicSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -296,6 +293,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
EXPECT_THAT(root, op::Constant());
}
+TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
+ HloComputation::Builder builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+}
+
// Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -519,7 +531,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
@@ -1826,6 +1838,126 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
op::Reshape(op::Broadcast(param)));
}
+TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(HloInstruction::CreateIota(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2));
+ Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
+ builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
+ auto result_shape = iota->shape();
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ auto root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Constant()));
+ EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension(),
+ 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ const int64 iota_dim =
+ Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension();
+ EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
@@ -2012,6 +2144,264 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
}
+// Used for TEST_Ps that test merging (or not) of a kPad instruction into a
+// convolution's Window.
+struct ConvPaddingTestcase {
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window)
+ : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
+ /*pad_value=*/0) {}
+
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window, float pad_value)
+ : padding(padding),
+ orig_conv_window(orig_conv_window),
+ expected_conv_window(expected_conv_window),
+ pad_value(pad_value) {}
+
+ string ToString() const {
+ return absl::StrFormat(
+ "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
+ "pad_value=%f",
+ padding, orig_conv_window, expected_conv_window, pad_value);
+ }
+
+ string padding;
+ string orig_conv_window;
+ string expected_conv_window;
+ float pad_value;
+};
+
+// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
+// computation that does
+//
+// conv(pad(param0, padding=padding), param1), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvInputPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvInputPaddingTestCases, ConvInputPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+ // Merge this edge padding into the conv.
+ {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
+ // Merge this edge padding with the conv's edge padding.
+ {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
+ // Merge this interior-padded kPad with the unpadded conv. The 3x6
+ // interior padding gets transformed to 4x7 conv lhs dilation.
+ {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
+ // kPad has dilation on one dim, conv has it on the other; merge them.
+ {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
+ // kPad has dilation and edge padding on one dim, conv has them on the
+ // other; merge them.
+ {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
+ "pad=0_1x3_0 lhs_dilate=2x10"},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
+
+ // We refuse to transform the following because on some dimension, one
+ // of the kPad and conv has dilation and the other has some sort of
+ // padding.
+ {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+
+ // We can't merge feature or batch padding into the conv.
+ {"1_0x0_0x0_0x0_0", "", ""},
+ {"0_0x1_0x0_0x0_0", "", ""},
+ }));
+
+TEST_P(ConvInputPaddingTest, DoTest) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01
+ "input"));
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ input, pad_value, padding_config));
+
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1,
+ ShapeUtil::MakeShape(
+ F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window =
+ ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
+ .ValueOrDie();
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
+ window, dnums)
+ .ValueOrDie(),
+ lhs_pad, filter, window, dnums));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrCat("size=3x3 ", testcase.expected_conv_window));
+ }
+}
+
+// ConvFilterPaddingTest (and its one associated TEST_P) checks that a
+// computation that does
+//
+// conv(param0, pad(param1, padding=padding)), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvFilterPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvFilterPaddingTestCases, ConvFilterPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+ // Can only merge interior padding on the filter's spatial dimensions;
+ // all
+ // other paddings (edge padding and interior padding on the channel
+ // dims)
+ // should be rejected out of hand.
+ {"1_0_0x0_0_0x0_0x0_0", "", ""},
+ {"0_1_0x0_0_0x0_0x0_0", "", ""},
+ {"0_0_1x0_0_0x0_0x0_0", "", ""},
+ {"0_0_0x1_0_0x0_0x0_0", "", ""},
+ {"0_0_0x0_1_0x0_0x0_0", "", ""},
+ {"0_0_0x0_0_1x0_0x0_0", "", ""},
+ {"0_0_0x0_0_0x1_0x0_0", "", ""},
+ {"0_0_0x0_0_0x0_1x0_0", "", ""},
+ {"0_0_0x0_0_0x0_0x1_0", "", ""},
+ {"0_0_0x0_0_0x0_0x0_1", "", ""},
+
+ // Interior padding on channel dims can be merged into the conv, so long
+ // as the conv and pad don't have interior padding on the same dim.
+ {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
+ {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
+ {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
+
+ // Can't merge if for a given dim there's interior padding on both the
+ // pad and conv.
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
+ }));
+
+TEST_P(ConvFilterPaddingTest, DoIt) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01
+ "input"));
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ filter, pad_value, padding_config));
+
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeShape(
+ F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
+ rhs_pad->shape().dimensions(2),
+ rhs_pad->shape().dimensions(3),
+ testcase.orig_conv_window))
+ .ValueOrDie();
+ auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ window, dnums)
+ .ValueOrDie(),
+ input, rhs_pad, window, dnums));
+
+ // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
+ // after the transformation.
+ PrecisionConfigProto precision_config;
+ precision_config.add_operand_precision(PrecisionConfigProto::HIGH);
+ precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST);
+ orig_conv->set_precision_config(precision_config);
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrFormat("size=%dx%d %s",
+ conv->operand(1)->shape().dimensions(2),
+ conv->operand(1)->shape().dimensions(3),
+ testcase.expected_conv_window));
+ EXPECT_THAT(
+ conv->precision_config().operand_precision(),
+ ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST));
+ }
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;
@@ -2115,7 +2505,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
auto out_dims = in_dims;
out_dims[in_channel_idx] = options.f_output_channels;
- auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
+ auto make_shape = [](absl::Span<const int64> dims,
bool minor_to_major_layout) {
if (minor_to_major_layout) {
return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
@@ -2653,6 +3043,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
}
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
+ HloComputation::Builder builder(TestName());
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;
@@ -2686,8 +3117,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
// a and b are parallel bounds we can either turn into a B F S0 S1 or
// `B S0 S1 F` kind of pattern.
- auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
- int64 a, int64 b) {
+ auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
+ int64 b) {
std::vector<int64> result;
if (param.prepend_a) {
result.push_back(a);
@@ -2856,12 +3287,7 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {
- public:
- DotOfConcatSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3034,12 +3460,7 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {
- public:
- DotOfGatherSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index a6889cb171..5c180cbdd4 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -112,11 +112,11 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
return stream_pools_.at(executor).BorrowStream(executor);
}
-Backend::Backend(
- se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager, ComputationPlacer* computation_placer,
- int intra_op_parallelism_threads)
+Backend::Backend(se::Platform* platform, Compiler* compiler,
+ absl::Span<se::StreamExecutor* const> stream_executors,
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 4a6a78daf0..a2dafbe803 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -149,7 +149,7 @@ class Backend {
private:
struct EigenThreadPoolWrapper;
Backend(se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
+ absl::Span<se::StreamExecutor* const> stream_executors,
TransferManager* transfer_manager,
ComputationPlacer* computation_placer,
int intra_op_parallelism_threads);
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index b342acb025..38f1a5d3a6 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,12 +24,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {
- public:
- BatchDotSimplificationTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 01931b2d02..ec281ae68f 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
index 1b8b2d2045..d63287539d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 32573ed355..d5b1148058 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -69,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
// Inserts conversion HLOs to replace the called computations' BF16
// operands/outputs to F32.
Status ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
@@ -114,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
}
Status BFloat16NormalizationVisitor::ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
std::map<HloComputation*, HloComputation*> cloned_computations;
for (auto& comp : bf16_called_comps) {
auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
@@ -359,6 +357,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
+ // TODO(b/112040122): Correctly normalize variadic reduce.
if ((hlo->opcode() == HloOpcode::kSort ||
hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
ShapeUtil::IsTuple(hlo->shape())) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 2fb401c428..545a6ecfb1 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
HloInstruction* hlo) {
auto adjust_computation =
[this, hlo](HloComputation* computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
// Adjust parameters.
CHECK_EQ(operands.size(), computation->num_parameters());
for (int64 i = 0; i < operands.size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index b11f15ec7b..8b8c6bfd26 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -58,12 +58,65 @@ string ColocatedBufferSetsToString(const T& container, const char* title) {
return result;
}
-// Walk the call graph of the HLO module and place each computation into either
-// thread_local_computations or global_computations depending upon whether the
-// computation requires thread-local allocations or global allocations. The
-// elements in thread_local_computations and global_computations are in post
-// order (if computation A has an instruction which calls computation B, then A
-// will appear after B in the vector).
+// Checks that points-to set of 'instruction' is unambiguous and distinct
+// (ensured by CopyInsertion), then adds the buffer from the points-to set at
+// 'index' to 'colocated_set'.
+const LogicalBuffer* AddBufferToColocatedSet(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis,
+ std::vector<const LogicalBuffer*>* colocated_set) {
+ // CopyInsertion ensures root points-to set is unambiguous and distinct.
+ const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
+ DCHECK(!points_to.IsAmbiguous());
+ colocated_set->push_back(points_to.element(index)[0]);
+ return colocated_set->back();
+}
+
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), perform graph coloring such that interfering nodes are
+// assigned to different colors. Returns the assigned color of the nodes, where
+// the colors are represented as integer values [0, color_count).
+std::vector<int64> ColorInterferenceGraph(
+ const std::vector<std::vector<int64>>& interference_map) {
+ const int64 node_count = interference_map.size();
+
+ // Sort the nodes such that we assign nodes with more interference first. This
+ // relies on the common heuristic of assigning the most constrained node
+ // first, but it would be good to investigate other ordering heuristics too.
+ std::vector<int64> nodes(node_count);
+ std::iota(nodes.begin(), nodes.end(), 0);
+ std::sort(nodes.begin(), nodes.end(),
+ [&interference_map](const int64 i, const int64 j) {
+ return interference_map[i].size() > interference_map[j].size();
+ });
+
+ const int64 kColorUnassigned = -1;
+ std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+ for (int64 node : nodes) {
+ // Mark the colors that are already assigned to the neighbors.
+ std::vector<bool> available_colors(node_count, true);
+ for (int64 neighbor : interference_map[node]) {
+ int64 color = assigned_colors[neighbor];
+ if (color != kColorUnassigned) {
+ available_colors[color] = false;
+ }
+ }
+
+ // Find the color that is not yet assigned to the neighbors.
+ int64 color = kColorUnassigned;
+ for (color = 0; color < available_colors.size(); ++color) {
+ if (available_colors[color]) {
+ break;
+ }
+ }
+ CHECK_NE(color, kColorUnassigned);
+ assigned_colors[node] = color;
+ }
+ return assigned_colors;
+}
+
+} // namespace
+
Status GatherComputationsByAllocationType(
const HloModule* module,
std::vector<const HloComputation*>* thread_local_computations,
@@ -165,65 +218,6 @@ Status GatherComputationsByAllocationType(
return Status::OK();
}
-// Checks that points-to set of 'instruction' is unambiguous and distinct
-// (ensured by CopyInsertion), then adds the buffer from the points-to set at
-// 'index' to 'colocated_set'.
-const LogicalBuffer* AddBufferToColocatedSet(
- const HloInstruction* instruction, const ShapeIndex& index,
- const TuplePointsToAnalysis& points_to_analysis,
- std::vector<const LogicalBuffer*>* colocated_set) {
- // CopyInsertion ensures root points-to set is unambiguous and distinct.
- const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
- DCHECK(!points_to.IsAmbiguous());
- colocated_set->push_back(points_to.element(index)[0]);
- return colocated_set->back();
-}
-
-// Given the interference map of a graph (the list of interfering node indices
-// for each node), perform graph coloring such that interfering nodes are
-// assigned to different colors. Returns the assigned color of the nodes, where
-// the colors are represented as integer values [0, color_count).
-std::vector<int64> ColorInterferenceGraph(
- const std::vector<std::vector<int64>>& interference_map) {
- const int64 node_count = interference_map.size();
-
- // Sort the nodes such that we assign nodes with more interference first. This
- // relies on the common heuristic of assigning the most constrained node
- // first, but it would be good to investigate other ordering heuristics too.
- std::vector<int64> nodes(node_count);
- std::iota(nodes.begin(), nodes.end(), 0);
- std::sort(nodes.begin(), nodes.end(),
- [&interference_map](const int64 i, const int64 j) {
- return interference_map[i].size() > interference_map[j].size();
- });
-
- const int64 kColorUnassigned = -1;
- std::vector<int64> assigned_colors(node_count, kColorUnassigned);
- for (int64 node : nodes) {
- // Mark the colors that are already assigned to the neighbors.
- std::vector<bool> available_colors(node_count, true);
- for (int64 neighbor : interference_map[node]) {
- int64 color = assigned_colors[neighbor];
- if (color != kColorUnassigned) {
- available_colors[color] = false;
- }
- }
-
- // Find the color that is not yet assigned to the neighbors.
- int64 color = kColorUnassigned;
- for (color = 0; color < available_colors.size(); ++color) {
- if (available_colors[color]) {
- break;
- }
- }
- CHECK_NE(color, kColorUnassigned);
- assigned_colors[node] = color;
- }
- return assigned_colors;
-}
-
-} // namespace
-
size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
uint64 h = std::hash<int64>()(s.index());
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 94495290c1..24ba7c16f5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
@@ -41,6 +41,17 @@ limitations under the License.
namespace xla {
+// Walk the call graph of the HLO module and place each computation into either
+// thread_local_computations or global_computations depending upon whether the
+// computation requires thread-local allocations or global allocations. The
+// elements in thread_local_computations and global_computations are in post
+// order (if computation A has an instruction which calls computation B, then A
+// will appear after B in the vector).
+Status GatherComputationsByAllocationType(
+ const HloModule* module,
+ std::vector<const HloComputation*>* thread_local_computations,
+ std::vector<const HloComputation*>* global_computations);
+
// This class abstracts an allocation of contiguous memory which can hold the
// values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
// of the allocation, represented by a Slice. A single BufferAllocation may hold
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 52abda16c4..8bd1533972 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -79,9 +79,8 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
return main_list.GetInstructions();
}
-class BufferAssignmentTest : public HloTestBase {
+class BufferAssignmentTest : public HloVerifiedTestBase {
protected:
- BufferAssignmentTest() {}
~BufferAssignmentTest() override {}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
@@ -119,7 +118,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
HloModule* module,
- tensorflow::gtl::ArraySlice<const HloInstruction*> instruction_sequence,
+ absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
SequentialHloOrdering::HloModuleSequence module_sequence;
module_sequence[module->entry_computation()] =
@@ -148,6 +147,17 @@ class BufferAssignmentTest : public HloTestBase {
return builder.Build();
}
+ std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ auto param2 =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
+ return builder.Build();
+ }
+
// Builds a simple compare-to-limit (x < 4) computation for a While.
//
// condition:
@@ -164,8 +174,8 @@ class BufferAssignmentTest : public HloTestBase {
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4));
return builder.Build();
}
@@ -312,12 +322,12 @@ TEST_F(BufferAssignmentTest, ScalarConstant) {
module->AddEntryComputation(builder.Build());
{
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
}
{
- auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
+ auto buffers = RunBufferAssignmentNoBuffersForConstants(module);
EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
}
}
@@ -336,13 +346,13 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
module->AddEntryComputation(builder.Build());
{
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
GetAssignedOutputAllocation(*buffers, add);
}
{
- auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
+ auto buffers = RunBufferAssignmentNoBuffersForConstants(module);
EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
GetAssignedOutputAllocation(*buffers, add);
@@ -364,7 +374,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
// reports for the instruction directly.
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
@@ -387,7 +397,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// The copy node now has an output buffer.
GetAssignedOutputAllocation(*buffers, copy);
}
@@ -401,12 +411,14 @@ TEST_F(BufferAssignmentTest, Basic) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -414,7 +426,7 @@ TEST_F(BufferAssignmentTest, Basic) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -448,12 +460,14 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -473,7 +487,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
return Status::OK();
};
- auto buffers = RunColoredBufferAssignment(module.get(), colorer);
+ auto buffers = RunColoredBufferAssignment(module, colorer);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -507,12 +521,14 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -540,7 +556,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
return Status::OK();
};
- auto buffers = RunColoredBufferAssignment(module.get(), colorer);
+ auto buffers = RunColoredBufferAssignment(module, colorer);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -577,12 +593,14 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(
@@ -590,7 +608,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -641,7 +659,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";
// Assigns buffers and fetches sizes.
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
int64 size0 = ValidateBuffers(level0, *buffers);
int64 size1 = ValidateBuffers(level1, *buffers);
@@ -676,10 +694,10 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
// output. (Reuse is not safe in the general case, as it reshapes and some
// out-of-order reductions could overwrite an element before a use.)
//
- // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3)
+ // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
auto module = CreateNewModule();
auto reduce_computation =
- module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
+ module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
@@ -700,7 +718,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
ValidateBuffers(instrs, *buffers);
@@ -756,7 +774,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
// Assigns buffers and fetches sizes.
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
int64 size0 = ValidateBuffers(level0, *buffers);
int64 sizec = ValidateBuffers(levelc, *buffers);
int64 sizeb = ValidateBuffers(levelb, *buffers);
@@ -821,7 +839,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) {
EXPECT_EQ(2, true_instrs.size());
EXPECT_EQ(2, false_instrs.size());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
ValidateBuffers(conditional_instrs, *buffers);
ValidateBuffers(true_instrs, *buffers);
ValidateBuffers(false_instrs, *buffers);
@@ -859,7 +877,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// tanh and exp2 can reuse exp1's buffer
EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
@@ -888,7 +906,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// negate and broadcast should share a buffer.
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
@@ -921,7 +939,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The instructions should not share buffers.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -958,7 +976,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The instructions should not share buffers.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -993,7 +1011,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The broadcast output buffer cannot be shared.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -1025,7 +1043,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// negate and broadcast should share a buffer.
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
@@ -1063,7 +1081,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The broadcast output buffer cannot be shared.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -1107,7 +1125,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
HloInstruction::CreateMap(vec_shape, {call}, map_computation));
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Allocations for the map computation should be thread-local and not
// live-out.
@@ -1156,7 +1174,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// There should be four allocations: one for vector of pointers, and one for
// each tuple element.
@@ -1192,7 +1210,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Only some of the elements of the input param are liveout.
EXPECT_FALSE(
@@ -1235,7 +1253,7 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(3, assignment->Allocations().size());
}
@@ -1249,7 +1267,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
/*operands=*/{}, /*custom_call_target=*/"foo_function"));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(3, assignment->Allocations().size());
EXPECT_TRUE(
@@ -1280,7 +1298,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(2, assignment->Allocations().size());
// Buffers for call are colocated with the sub-computation.
@@ -1342,7 +1360,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
module->AddEntryComputation(std::move(a_computation));
module->AddEmbeddedComputation(std::move(b_computation));
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Buffers for call are colocated with the sub-computations.
EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
@@ -1378,7 +1396,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Bitcast should get the same allocation as the param.
EXPECT_EQ(1, assignment->Allocations().size());
@@ -1405,7 +1423,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Select shallow copies one of its operands so it defines its own top-level
// buffer and receives its own allocation.
@@ -1443,7 +1461,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// There should be no buffer reuse. The copy should not reuse the tuple
// buffer.
@@ -1477,12 +1495,12 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
auto dot_bc = builder.AddInstruction(
HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
builder.AddInstruction(
- HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1));
+ HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
// Run buffer assignment with alignment=1.
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
+ auto assignment = RunBufferAssignment(module, /*alignment=*/1);
// There are 5 allocations: 3 parameters, 1 output, and 1 temp.
EXPECT_EQ(5, assignment->Allocations().size());
@@ -1501,7 +1519,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
EXPECT_EQ(80, slice_bc.allocation()->size());
// Re-run buffer assignment with alignment=64.
- assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
+ assignment = RunBufferAssignment(module, /*alignment=*/64);
EXPECT_EQ(5, assignment->Allocations().size());
slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
@@ -1532,12 +1550,14 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1545,16 +1565,13 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
- // Trivially, the set of peak memory logical buffer(s) of an allocation with a
- // single logical buffer should be exactly the logical buffer in that
- // allocation.
const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
const std::vector<const LogicalBuffer*>& peak_buffers =
mul_buffer.PeakMemoryLogicalBuffers();
ASSERT_EQ(peak_buffers.size(), 1);
- EXPECT_EQ(peak_buffers[0]->instruction(), mul);
+ EXPECT_EQ(peak_buffers[0]->instruction(), broadcast);
}
TEST_F(BufferAssignmentTest, PeakBuffers) {
@@ -1590,7 +1607,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignmentWithInstructionSequence(
- module.get(), {param, log, rev, neg, concat, root});
+ module, {param, log, rev, neg, concat, root});
// The temporary buffer should hold the 4 interior instructions.
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
@@ -1646,7 +1663,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0}));
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast);
const std::vector<const LogicalBuffer*>& peak_buffers =
buffer.PeakMemoryLogicalBuffers();
@@ -1696,15 +1713,13 @@ ENTRY main {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(hlo_text));
-
+ ParseAndVerifyModule(hlo_text);
HloInstruction* constant_1 =
- module->entry_computation()->GetInstructionWithName("constant.1.1");
+ module().entry_computation()->GetInstructionWithName("constant.1.1");
HloInstruction* constant_2 =
- module->entry_computation()->GetInstructionWithName("constant.1.2");
+ module().entry_computation()->GetInstructionWithName("constant.1.2");
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(&module());
{
const BufferAllocation& allocation_for_const_1 =
@@ -1733,7 +1748,7 @@ ENTRY main {
}
}
-class WhileBufferAssignmentTest : public HloTestBase {
+class WhileBufferAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
const string& name) {
@@ -1807,9 +1822,9 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -1833,8 +1848,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// Verify 'input0' and read-only use while0{0} alias.
EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
@@ -1890,20 +1905,20 @@ ENTRY %test_module {
ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
// Run CopyInsertion and check if the graph constructed above doesn't need
// any copies inserted for BufferAssignment to run.
- int64 instruction_count = module->instruction_count();
+ int64 instruction_count = module().instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
- ASSERT_EQ(instruction_count, module->instruction_count());
+ ASSERT_IS_OK(copy_insertion.Run(&module()).status());
+ ASSERT_EQ(instruction_count, module().instruction_count());
// Get the instructions in the module.
- const HloInstruction* bcast = module->entry_computation()->root_instruction();
+ const HloInstruction* bcast =
+ module().entry_computation()->root_instruction();
const HloInstruction* param =
- module->entry_computation()->parameter_instruction(0);
+ module().entry_computation()->parameter_instruction(0);
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
const HloInstruction* while1 = bcast->operand(0);
ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
@@ -1911,7 +1926,7 @@ ENTRY %test_module {
ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
// Run buffer assignment.
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(&module());
TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
assignment->GetUniqueSlice(param, {}));
TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
@@ -1958,20 +1973,20 @@ ENTRY %test_module {
ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
// Run CopyInsertion and check if the graph constructed above doesn't need
// any copies inserted for BufferAssignment to run.
- int64 instruction_count = module->instruction_count();
+ int64 instruction_count = module().instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
- ASSERT_EQ(instruction_count, module->instruction_count());
+ ASSERT_IS_OK(copy_insertion.Run(&module()).status());
+ ASSERT_EQ(instruction_count, module().instruction_count());
// Get the instructions in the module.
- const HloInstruction* bcast = module->entry_computation()->root_instruction();
+ const HloInstruction* bcast =
+ module().entry_computation()->root_instruction();
const HloInstruction* constant =
- module->entry_computation()->GetInstructionWithName("constant.42");
+ module().entry_computation()->GetInstructionWithName("constant.42");
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
const HloInstruction* while1 = bcast->operand(0);
ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
@@ -1979,7 +1994,7 @@ ENTRY %test_module {
ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
// Run buffer assignment.
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(&module());
TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
assignment->GetUniqueSlice(constant, {}));
TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
@@ -2072,7 +2087,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// any copies inserted for BufferAssignment to run.
int64 instruction_count = module->instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+ ASSERT_IS_OK(copy_insertion.Run(module).status());
ASSERT_EQ(instruction_count, module->instruction_count());
// Create a sequential order among all the instructions in the entry
@@ -2084,8 +2099,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
BufferAssigner::Run(
- module.get(),
- absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
+ module, absl::make_unique<SequentialHloOrdering>(module, sequence),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
@@ -2122,7 +2136,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2143,8 +2157,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// while0 and while1 buffers should be completely aligned.
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
@@ -2186,13 +2200,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
{
FlattenCallGraph flatten;
- TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
}
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
@@ -2216,15 +2230,14 @@ ENTRY Main {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- HloRunner::CreateModuleFromString(
- hlo_text, legacy_flags::GetDebugOptionsFromFlags()));
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ ParseAndVerifyModule(hlo_text, config);
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(&module());
- HloComputation* main = module->entry_computation();
- HloComputation* callee = module->GetComputationWithName("Callee");
+ HloComputation* main = module().entry_computation();
+ HloComputation* callee = module().GetComputationWithName("Callee");
EXPECT_NE(callee, nullptr);
HloInstruction* param0 = callee->parameter_instruction(0);
@@ -2284,14 +2297,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto input1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, data_shape_, "input1"));
auto weights1 = builder.AddInstruction(
HloInstruction::CreateParameter(3, data_shape_, "weights1"));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, one, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto cond =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2311,18 +2324,18 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
- auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
- while0->shape(), HloOpcode::kAdd, gte0, gte1));
+ auto root_add = builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
module->AddEntryComputation(builder.Build());
{
FlattenCallGraph flatten;
- TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module));
EXPECT_TRUE(result);
}
- RunCopyInsertion(module.get());
+ RunCopyInsertion(module);
auto sequence =
ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
@@ -2341,8 +2354,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto assignment =
BufferAssigner::Run(
- module.get(),
- absl::make_unique<SequentialHloOrdering>(module.get(), sequence),
+ module, absl::make_unique<SequentialHloOrdering>(module, sequence),
ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true)
@@ -2363,9 +2375,9 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2396,8 +2408,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// Get BufferAllocation for root instruction.
auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
.ConsumeValueOrDie()
diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h
index f4be16e084..69b3646356 100644
--- a/tensorflow/compiler/xla/service/buffer_value.h
+++ b/tensorflow/compiler/xla/service/buffer_value.h
@@ -19,12 +19,12 @@ limitations under the License.
#include <functional>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
index d773558c28..52037bf9b5 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.h
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <map>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 3079695e96..e5a6c28478 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 1ac950bdd6..61136a3e11 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -50,12 +50,12 @@ class CompileOnlyService : public Service {
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata);
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 34f7fe12ca..1fdda31c34 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 6c477da038..c43a31b167 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,10 +39,6 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
- ConditionalSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 1b7a7b36ea..b65dfef9c9 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -479,7 +479,7 @@ class CopyRemover {
// 'values' an entry is created in value_to_node which indicates the
// respective ValueNode representing that value.
void AddValueList(
- tensorflow::gtl::ArraySlice<const HloValue*> values,
+ absl::Span<const HloValue* const> values,
tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
ValueNode* tail = nullptr;
ValueNode* head = nullptr;
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 4cd192873f..d412578619 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -51,6 +51,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -63,6 +64,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -89,6 +91,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
":target_machine_features",
+ "@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
@@ -236,6 +239,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
"@llvm//:orc_jit",
],
)
@@ -286,6 +290,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
"@llvm//:code_gen",
"@llvm//:core",
"@llvm//:support",
@@ -331,6 +336,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -461,6 +467,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -666,6 +673,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -760,6 +768,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -915,6 +924,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
index 408fe0f5bf..1942ea1a2a 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
@@ -40,7 +40,7 @@ std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
}
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
+ absl::Span<const BufferInfo> buffer_infos) {
std::vector<int32> result;
for (int64 i = 0; i < buffer_infos.size(); i++) {
if (buffer_infos[i].is_entry_parameter()) {
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
index 05de70c726..e9ee928ab2 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment(
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
+ absl::Span<const ::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos);
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6420180b13..796f36510e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -588,8 +588,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
DFSMemoryScheduler));
- // Run buffer analysis on the HLO graph. This analysis figures out which
- // temporary buffers are required to run the computation.
+ // Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(module.get(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 47b5edabff..f2af923782 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 08773693fb..29abf38e43 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -75,9 +75,9 @@ CpuExecutable::CpuExecutable(
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
-CpuExecutable::CreateTempArray(
+CpuExecutable::CreateBufferTable(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<OwningDeviceMemory> owning_buffers(
@@ -136,19 +136,19 @@ CpuExecutable::CreateTempArray(
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+ absl::Span<const se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
//
// void function(void* result, const void* run_options, void** args_array,
- // void** temps_array)
+ // void** buffer_table)
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
// args_array: null
- // temps_array: An array of pointers, containing pointers to temporary buffers
- // required by the executable adn pointers to entry computation
- // parameters.
+ // buffer_table: An array of pointers, containing pointers to temporary
+ // buffers required by the executable adn pointers to entry computation
+ // parameters.
//
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -172,7 +172,7 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << absl::StrFormat(
- " func(void* result, void* params[null], void* temps[%u], "
+ " func(void* result, void* params[null], void* buffer_table[%u], "
"uint64 profile_counters[%u])",
buffer_pointers.size(), profile_counters_size);
VLOG(3) << absl::StrFormat(" result = %p", result_buffer);
@@ -181,7 +181,8 @@ Status CpuExecutable::ExecuteComputeFunction(
};
VLOG(3) << " params = nullptr";
VLOG(3) << absl::StrFormat(
- " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer));
+ " buffer_table = [%s]",
+ absl::StrJoin(buffer_pointers, ", ", ptr_printer));
VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters);
}
@@ -207,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
+ absl::Span<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(),
@@ -245,7 +246,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
TF_ASSIGN_OR_RETURN(
auto result,
@@ -256,7 +257,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
if (hlo_profiling_enabled()) {
return Unimplemented(
"Asynchronous execution on stream with hlo profiling is not yet "
@@ -267,7 +268,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -281,11 +282,12 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
+ CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
- TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &owning_buffers));
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer result,
+ CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers)));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@@ -298,7 +300,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
//
// We also need to change the types of some of the variables we capture:
// run_options needs to change from a pointer to a value type, and arguments
- // needs to change from an ArraySlice into a vector. We use a struct instead
+ // needs to change from a Span into a vector. We use a struct instead
// of a lambda to make this explicit.
struct AsyncRunTask {
CpuExecutable* executable;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 96e53de57e..3c3c047bfe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -57,12 +57,12 @@ class CpuExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
// This should be called after set_ir_module_string.
const string& ir_module_string() const { return ir_module_string_; }
@@ -74,9 +74,10 @@ class CpuExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape);
// Type of the computation function we expect in the JIT.
- using ComputeFunctionType = void (*)(
- void* /*result*/, const ExecutableRunOptions* /*run_options*/,
- const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);
+ using ComputeFunctionType =
+ void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
+ const void** /*args*/, void** /*buffer_table*/,
+ int64* /*profile_counters*/);
const ComputeFunctionType& compute_function() const {
return compute_function_;
@@ -92,18 +93,18 @@ class CpuExecutable : public Executable {
// exists) must out-live the task.
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile);
- // Creates an array suitable for passing as the "temps" argument to the JIT
- // compiled function pointer.
+ // Creates an array suitable for passing as the "buffer_table" argument to the
+ // JIT compiled function pointer.
//
// Returns (unowning_buffers, owning_buffers) where:
//
- // - unowning_buffers.data() can be passed as the temps argument as-is and
- // includes pointers to the scratch storage required by the computation,
- // the live-out buffer into which the result will be written and entry
- // computation parameters.
+ // - unowning_buffers.data() can be passed as the buffer_table argument as-is
+ // and includes pointers to the scratch storage required by the
+ // computation, the live-out buffer into which the result will be written
+ // and entry computation parameters.
//
// - owning_buffers contains owning pointers to the buffers that were
// allocated by this routine. This routine allocates buffers for temporary
@@ -111,22 +112,21 @@ class CpuExecutable : public Executable {
// result.
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
- CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
- Status ExecuteComputeFunction(
- const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- HloExecutionProfile* hlo_execution_profile);
+ Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
+ absl::Span<const se::DeviceMemoryBase> buffers,
+ HloExecutionProfile* hlo_execution_profile);
// Creates a ScopedShapedBuffer for holding the result of the computation,
// moving buffers out of allocated_buffers and into the result as appropriate.
// The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
+ absl::Span<OwningDeviceMemory> buffers);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 7f867fa149..f9cd61bea3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -35,7 +35,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicSlice ||
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kGather ||
- hlo.opcode() == HloOpcode::kPad ||
+ hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kReverse ||
hlo.opcode() == HloOpcode::kSlice ||
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 28aaa28cdb..284929ca07 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include <set>
#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 3681d12d8d..9363af3b89 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 639064040f..8a44c384bb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 0df2abf001..5519a43b2f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -179,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
int64 size = GetByteSizeRequirement(literal_shape);
// Note: OSS build didn't like implicit conversion from
// literal_shape.dimensions() to the array slice on 2017-07-10.
- tensorflow::gtl::ArraySlice<int64> dimensions(
+ absl::Span<const int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
TF_ASSIGN_OR_RETURN(
@@ -225,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data) {
+ absl::Span<const std::pair<void*, int64>> buffer_data) {
return TransferBuffersFromOutfeedInternal(executor, buffer_data,
/*is_tuple=*/true);
}
@@ -238,8 +238,7 @@ StatusOr<Shape> CpuTransferManager::TransferArrayBufferFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple) {
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple) {
std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
for (auto b : buffer_data) {
int64 size = b.second;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 7b938e9fd7..361d4b9c84 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -18,13 +18,13 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager {
// Helper that transfers a tuple of element buffers from the device's outfeed.
StatusOr<Shape> TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data);
+ absl::Span<const std::pair<void*, int64>> buffer_data);
// Helper that transfers an array buffer from the device's outfeed.
StatusOr<Shape> TransferArrayBufferFromOutfeed(se::StreamExecutor* executor,
@@ -68,8 +68,7 @@ class CpuTransferManager : public GenericTransferManager {
// for the given buffers.
StatusOr<Shape> TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple);
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple);
TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager);
};
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index dd060f54a2..99fa707c95 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -80,7 +80,7 @@ class MemoryTile {
// `minor_dim_offset`}.
//
// Note: `major_dim_offset` is a parameter to the constructor.
- void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+ void StoreTile(absl::Span<llvm::Value* const> tile,
llvm::Value* minor_dim_offset) const {
CHECK_EQ(tile.size(), pointers_.size());
for (int64 i = 0; i < pointers_.size(); i++) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 460363e18f..e5cf15c686 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -29,6 +29,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
+#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
@@ -66,7 +67,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -100,6 +100,11 @@ IrEmitter::IrEmitter(
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_cpu_enable_fast_math()));
+ Status s = GatherComputationsByAllocationType(
+ &hlo_module, &thread_local_computations_, &global_computations_);
+ absl::c_sort(thread_local_computations_);
+ absl::c_sort(global_computations_);
+ TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
@@ -337,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Write the tuple index table.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(infeed, {0}));
- llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
assignment_.GetUniqueSlice(infeed, {1}));
- llvm::Value* token_address = EmitTempBufferPointer(
+ llvm::Value* token_address = EmitBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
module_);
@@ -363,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
// instruction. Target addresses for internal elements can be obtained
- // from EmitTempBufferPointer.
+ // from EmitBufferPointer.
llvm::Value* tuple_element_address =
- EmitTempBufferPointer(buffer, tuple_element_shape);
+ EmitBufferPointer(buffer, tuple_element_shape);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(
XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
@@ -460,6 +465,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
}
Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
+ // Outfeed produces no useful result, but it does return a token[] that can be
+ // threaded through to other side effecting operations to ensure ordering. In
+ // the IR emitter we treat this token as a normal u8[] and thus need to insert
+ // an entry for it in emitted_value_.
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
+
HloInstruction* operand = outfeed->operands()[0];
const Shape& operand_shape = operand->shape();
@@ -500,8 +511,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- absl::string_view name) {
+ absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
@@ -1195,7 +1205,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
const Shape& operand_shape = crs->operand(i)->shape();
CHECK(ShapeUtil::IsArray(operand_shape))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+ operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
@@ -1449,7 +1459,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment) {
ShardedVector accumulator;
accumulator.reserve(accumulator_type.size());
@@ -1545,7 +1555,7 @@ void IrEmitter::EmitShardedVectorStore(
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- gtl::ArraySlice<int64> dimensions, HloComputation* function,
+ absl::Span<const int64> dimensions, HloComputation* function,
string* failure_reason) {
if (!ReductionPreservesLayout(*reduce)) {
return false;
@@ -1695,7 +1705,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1752,7 +1762,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
string vectorization_failure_reason;
@@ -2092,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
{}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument());
HloInstruction* root = computation->root_instruction();
@@ -2107,7 +2117,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
- gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
+ absl::Span<HloInstruction* const> operands(custom_call->operands());
absl::string_view custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
@@ -2227,7 +2237,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
}
StatusOr<bool> IrEmitter::EmitFastConcatenate(
- HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
+ HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
string* failure_reason) {
if (ShouldEmitParallelLoopFor(*concatenate)) {
*failure_reason =
@@ -2363,7 +2373,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}
Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
- gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
string failure_reason;
TF_ASSIGN_OR_RETURN(
bool successful,
@@ -2612,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() {
return compute_function_->profile_counters_arg();
}
-llvm::Value* IrEmitter::GetTempBuffersArgument() {
- return compute_function_->temp_buffers_arg();
+llvm::Value* IrEmitter::GetBufferTableArgument() {
+ return compute_function_->buffer_table_arg();
}
llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
@@ -2679,11 +2689,11 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+llvm::Value* IrEmitter::EmitGlobalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
- GetTempBuffersArgument(), slice.index(), &b_);
+ GetBufferTableArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
@@ -2704,14 +2714,14 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
+llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape) {
if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ return EmitThreadLocalBufferPointer(slice, target_shape);
} else if (slice.allocation()->is_constant()) {
return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
} else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
+ return EmitGlobalBufferPointer(slice, target_shape);
}
}
@@ -2719,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
const Shape& target_shape = op->shape();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ llvm::Value* addr = EmitBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2748,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop(
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(target_op, {i}));
const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
- llvm::Value* op_target_address =
- EmitTempBufferPointer(slice, element_shape);
+ llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
output_arrays.push_back(
llvm_ir::IrArray(op_target_address, element_shape));
}
@@ -2794,8 +2803,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
Status IrEmitter::ElementTypesSameAndSupported(
const HloInstruction& instruction,
- gtl::ArraySlice<const HloInstruction*> operands,
- gtl::ArraySlice<PrimitiveType> supported_types) {
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types) {
for (auto operand : operands) {
TF_RET_CHECK(
ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
@@ -2825,9 +2834,10 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
}
llvm::Value* IrEmitter::EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) {
+ CHECK(absl::c_binary_search(thread_local_computations_, &callee));
+
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
@@ -2856,7 +2866,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
parameter_addrs, &b_, name,
/*return_value_buffer=*/return_value_buffer,
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
+ /*buffer_table_arg=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
/*profile_counters_arg=*/GetProfileCountersArgument()));
@@ -2865,13 +2875,15 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
absl::string_view name) {
+ CHECK(absl::c_binary_search(global_computations_, &callee));
+
Call(FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
/*parameter_addresses=*/{}, &b_, name,
/*return_value_buffer=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()),
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument()));
}
@@ -2884,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
const BufferAllocation::Slice root_buffer =
assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
+ return EmitBufferPointer(root_buffer, root_inst->shape());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index f98891246b..58a333b8fb 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -46,7 +47,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -62,8 +62,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Create a new LLVM IR emitter.
//
// hlo_module: the HLO module we are emitting IR for.
- // assignment: a BufferAssignment from which we know which temporary buffers
- // are used by the HLO nodes.
+ // assignment: a BufferAssignment from which we know which buffers are used by
+ // the HLO nodes.
// llvm_module: the LLVM module to emit IR into.
// instruction_to_profile_idx: the mapping from HLO instructions to their
// index in the profiling array.
@@ -111,7 +111,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emit code to map one element according to `map_instr`.
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ absl::Span<llvm::Value* const> elemental_operands,
absl::string_view name);
protected:
@@ -219,24 +219,21 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// argument of the computation function being emitted by this emitter.
llvm::Value* GetExecutableRunOptionsArgument();
- // Get the llvm::Value* that represents the "temps" argument of the
+ // Get the llvm::Value* that represents the "buffer_table" argument of the
// computation function being emitted by this emitter.
- llvm::Value* GetTempBuffersArgument();
+ llvm::Value* GetBufferTableArgument();
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitThreadLocalTempBufferPointer(
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape);
// Emits code that computes the address of the given buffer allocation slice.
- //
- // TODO(sanjoy): This should be renamed to reflect that it no longer provides
- // access to just temporaries.
- llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
// Emits a function into the current module. This can be used for
// computations embedded inside other computations, such as the
@@ -252,10 +249,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
//
// `parameters` holds the *scalar values* that need to be passed to the
// callee. The return value is the scalar returned by the callee.
- llvm::Value* EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- absl::string_view name);
+ llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
+ absl::Span<llvm::Value* const> parameters,
+ absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within
// a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to
@@ -271,8 +267,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// match and are of one of the given supported types.
Status ElementTypesSameAndSupported(
const HloInstruction& instruction,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types);
// Emit IR to perform a computation for every element in the given target op.
// This produces a series of nested loops (one for each dimension of the op's
@@ -319,10 +315,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// concepts that generalize over other vectorizable operations. We should
// consider pulling out these abstractions into a VectorizingIrEmitter or
// something similar.
- StatusOr<bool> EmitVectorizedReduce(
- HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
- string* failure_reason);
+ StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
+ HloInstruction* arg,
+ HloInstruction* init_value,
+ absl::Span<const int64> dimensions,
+ HloComputation* function,
+ string* failure_reason);
// We'd like to keep one or two one cache-line's worth of data in registers
// without generating IR with illegal (e.g. excessively large or
@@ -372,16 +370,15 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment);
// Tries to emit a fast concatenate operation using memcpy. Returns true if
// successful, and false on failure. On failure, sets "failure_reason" to a
// string describing why it could not emit a fast concatenate.
- StatusOr<bool> EmitFastConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- string* failure_reason);
+ StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
+ absl::Span<HloInstruction* const> operands,
+ string* failure_reason);
// Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
// from the address "source" to the address "target".
@@ -390,8 +387,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& source_array);
- // Assignment of the temporary buffers needed by the computation and their
- // shape information.
+ // Assignment of the buffers needed by the computation and their shape
+ // information.
const BufferAssignment& assignment_;
// The LLVM module into which IR will be emitted.
@@ -571,6 +568,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
+ std::vector<const HloComputation*> thread_local_computations_;
+ std::vector<const HloComputation*> global_computations_;
+
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 784045313d..adfb8392bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -78,19 +78,20 @@ void IrFunction::Initialize(const string& function_name,
const bool optimize_for_size_requested,
const bool enable_fast_math) {
// The function signature is:
- // void function(i8* retval, i8* run_options, i8** params, i8** temps,
+ // void function(i8* retval, i8* run_options, i8** params, i8**
+ // buffer_table,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
// For thread local functions:
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
- // temps: is null
+ // buffer_table: is null
//
// For global functions:
// retval: is null
// params: is null
- // temps: address of an array with pointers to temporary buffers and entry
- // computation parameters.
+ // buffer_table: address of an array with pointers to temporary buffers and
+ // entry computation parameters (but not to constant buffers).
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -116,7 +117,7 @@ void IrFunction::Initialize(const string& function_name,
// \---------/ \---------/ \-----------/
//
// /---------------------------------------------\
- // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
+ // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 |
// | addr | addr | | addr |
// \---------------------------------------------/
// | | |
@@ -134,9 +135,9 @@ void IrFunction::Initialize(const string& function_name,
// prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
// \---------------------------------------------/
- // Even though the type of params and temps is void** in the host's view, in
- // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
- // to use GEPs to unravel the indirection layers.
+ // Even though the type of params and buffer_table is void** in the host's
+ // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
+ // the code to use GEPs to unravel the indirection layers.
llvm::FunctionType* function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
/*Params=*/
@@ -160,8 +161,8 @@ void IrFunction::Initialize(const string& function_name,
exec_run_options_arg_ = &*arg_iter;
(++arg_iter)->setName("params");
parameters_arg_ = &*arg_iter;
- (++arg_iter)->setName("temps");
- temp_buffers_arg_ = &*arg_iter;
+ (++arg_iter)->setName("buffer_table");
+ buffer_table_arg_ = &*arg_iter;
if (num_dynamic_loop_bounds_ > 0) {
(++arg_iter)->setName("dynamic_loop_bounds");
dynamic_loop_bounds_arg_ = &*arg_iter;
@@ -200,10 +201,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// Returns an array of compute function call arguments (including parameter
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, absl::string_view name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
+ llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
if (parameter_addresses.empty()) {
@@ -230,7 +231,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
};
std::vector<llvm::Value*> arguments{
to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
- parameter_addresses_buffer, temp_buffers_arg};
+ parameter_addresses_buffer, buffer_table_arg};
if (profile_counters_arg != nullptr) {
arguments.push_back(profile_counters_arg);
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index ee7595f6e9..623a5f185f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -80,8 +80,9 @@ class IrFunction {
// Get the llvm::Value* that represents this functions parameters argument.
llvm::Value* parameters_arg() { return parameters_arg_; }
- // Get the llvm::Value* that represents this functions "temps" argument.
- llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; }
+ // Get the llvm::Value* that represents this functions "buffer_table"
+ // argument.
+ llvm::Value* buffer_table_arg() { return buffer_table_arg_; }
// Get the llvm::Value* that represents this functions "prof_counters"
// argument.
@@ -108,17 +109,17 @@ class IrFunction {
llvm::Argument* result_arg_;
llvm::Value* exec_run_options_arg_;
llvm::Value* parameters_arg_;
- llvm::Value* temp_buffers_arg_;
+ llvm::Value* buffer_table_arg_;
llvm::Value* dynamic_loop_bounds_arg_ = nullptr;
llvm::Value* profile_counters_arg_;
};
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, absl::string_view name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
+ llvm::Value* profile_counters_arg);
// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index a5f34908d7..2d9492eacf 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
//
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, uint64* prof_counters, int32 num_partitions,
+ void** buffer_table, uint64* prof_counters, int32 num_partitions,
int64* partitions, int32 num_partitioned_dims, void* function_ptr) {
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
@@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, nullptr, temps,
+ function(result_ptr, run_options_ptr, nullptr, buffer_table,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
@@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
}
// Call first compute function inline.
- function(result_ptr, run_options_ptr, params, temps, &partitions[0],
+ function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0],
prof_counters);
VLOG(3) << "ParallelForkJoin partition 0 done.";
bc.Wait();
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
index 1cf0ec6e3d..a279c7d2d6 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
@@ -24,7 +24,7 @@ extern "C" {
// threads before returning. See comments in runtime_fork_join.cc for details.
extern void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, tensorflow::uint64* prof_counters,
+ void** buffer_table, tensorflow::uint64* prof_counters,
tensorflow::int32 num_partitions, tensorflow::int64* partitions,
tensorflow::int32 num_partitioned_dims, void* function_ptr);
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index ae80a6f497..7d8e51f909 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) {
{
ShapePartitionIterator iterator(shape, {1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2});
EXPECT_EQ(2, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1)));
}
{
ShapePartitionIterator iterator(shape, {3});
EXPECT_EQ(3, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(2)));
}
}
@@ -128,20 +128,20 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
ShapePartitionIterator iterator(shape, {1, 1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2, 2});
EXPECT_EQ(4, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
+ absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
+ absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
+ absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
}
}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 780c07f819..e2c7af541e 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -54,6 +54,33 @@ CHECK: private constant [48 x i8]
/*match_optimized_ir=*/false);
}
+TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) {
+ const string hlo_text = R"(
+HloModule OutfeedTokenInTuple
+
+ENTRY main {
+ const = f32[] constant(42)
+ epoch = token[] after-all()
+ outfeed.tok = token[] outfeed(const, epoch)
+ ROOT root = (token[], f32[]) tuple(outfeed.tok, const)
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: Outfeed
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/false);
+}
} // namespace
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 962ea69c09..1bd4b59dd6 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -428,7 +428,7 @@ std::vector<llvm::Value*> TileVariable::Get() const {
return result;
}
-void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
+void TileVariable::Set(absl::Span<llvm::Value* const> value) {
CHECK_EQ(value.size(), storage_.size());
for (int64 i = 0, e = value.size(); i < e; i++) {
storage_[i].Set(value[i]);
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index c728f6df0a..5690d2be2f 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -324,7 +324,7 @@ class TileVariable {
std::vector<llvm::Value*> initial_value);
std::vector<llvm::Value*> Get() const;
- void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
+ void Set(absl::Span<llvm::Value* const> value);
private:
std::vector<VectorVariable> storage_;
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
index 47543b2082..b9e47f5aad 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
@@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() {
}
void XfeedQueueManager::EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers) {
+ absl::Span<XfeedBuffer* const> buffers) {
tensorflow::mutex_lock l(mu_);
bool was_empty = enqueued_buffers_.empty();
for (XfeedBuffer* b : buffers) {
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
index b4ace23260..990ff94ba2 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
@@ -22,10 +22,10 @@ limitations under the License.
#include <deque>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -63,8 +63,7 @@ class XfeedQueueManager {
// called when the buffer will no longer be accessed by the XfeedManager,
// either as a result of a call to Reset or because the runtime has dequeued
// and used the buffer.
- void EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers);
+ void EnqueueBuffersAtomically(absl::Span<XfeedBuffer* const> buffers);
// Blocks until the queue is non-empty, then returns the buffer at the head of
// the queue. Sets the current buffer to be the returned buffer. It is an
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index 37d1895d41..e727ba49cb 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,11 +26,6 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
- public:
- DefuserTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
index 1d0297cfbf..edbcb25247 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.cc
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -25,7 +25,7 @@ namespace xla {
StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors)
+ absl::Span<se::StreamExecutor* const> stream_executors)
: DeviceMemoryAllocator(platform),
stream_executors_(stream_executors.begin(), stream_executors.end()) {}
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
index d87b86caf0..a2308ee7a4 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.h
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
public:
StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<se::StreamExecutor* const> stream_executors);
StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
bool retry_on_failure) override;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index f6f8fc5a2a..5761573791 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 4f620e4c3a..4cd10ab06c 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 813e93fafa..4bb1e071d8 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -856,7 +856,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(b_->getFloatTy(), f);
};
- auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
+ auto multiply_add = [&](absl::Span<const float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
coefficients.remove_prefix(1);
@@ -893,7 +893,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
SetToFirstInsertPoint(if_data.true_block, b_);
{
llvm::Value* lw = FSub(w, getFloat(2.5f));
- tensorflow::gtl::ArraySlice<float> lq{
+ absl::Span<const float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
@@ -908,7 +908,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- tensorflow::gtl::ArraySlice<float> gq{
+ absl::Span<const float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
@@ -2117,29 +2117,40 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
iota->shape().dimensions(iota->iota_dimension())};
elem_index_linear = elem_index.Linearize(iota_bound, b_);
}
- if (ShapeUtil::ElementIsIntegral(iota->shape())) {
- return b_->CreateIntCast(
+ Shape component_shape =
+ ShapeUtil::ElementIsComplex(iota->shape())
+ ? ShapeUtil::ComplexComponentShape(iota->shape())
+ : iota->shape();
+ PrimitiveType component_element_type = component_shape.element_type();
+ llvm::Value* iota_result;
+ if (ShapeUtil::ElementIsIntegral(component_shape)) {
+ iota_result = b_->CreateIntCast(
elem_index_linear,
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
/*isSigned=*/false);
} else {
- TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape()))
- << element_type;
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape))
+ << component_element_type;
llvm::Type* float_ir_type;
- if (element_type == BF16) {
+ if (component_element_type == BF16) {
float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
} else {
float_ir_type =
- llvm_ir::PrimitiveTypeToIrType(element_type, module_);
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
}
llvm::Value* float_val =
b_->CreateUIToFP(elem_index_linear, float_ir_type);
- if (element_type == BF16) {
- return EmitF32ToBF16(float_val, b_);
+ if (component_element_type == BF16) {
+ iota_result = EmitF32ToBF16(float_val, b_);
} else {
- return float_val;
+ iota_result = float_val;
}
}
+ if (ShapeUtil::ElementIsComplex(iota->shape())) {
+ return EmitComposeComplex(iota, iota_result, nullptr);
+ } else {
+ return iota_result;
+ }
};
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 5ab0756219..1b3be199f6 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -28,8 +28,7 @@ using absl::nullopt;
class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 78edf918a4..47c56e2f7f 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -26,13 +26,12 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
-using tensorflow::gtl::ArraySlice;
namespace xla {
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
- ArraySlice<const ServiceExecutableRunOptions> run_options,
- ArraySlice<ArraySlice<const ShapedBuffer*>> arguments) {
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
TF_RET_CHECK(run_options.size() == arguments.size());
std::vector<ScopedShapedBuffer> return_values;
@@ -63,7 +62,7 @@ StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
se::Stream* stream = run_options->stream();
std::unique_ptr<se::Timer> timer;
if (profile != nullptr) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 98eaeee30a..3a6780f2a6 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -18,7 +18,10 @@ limitations under the License.
#include <memory>
#include <utility>
+#include <vector>
+#include "absl/types/span.h"
+#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -26,18 +29,33 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace xla {
+// ExecutionOutput encapsulates the output buffers of a execution and the
+// leftover buffers to be released by the caller.
+struct ExecutionOutput {
+ ExecutionOutput(ScopedShapedBuffer result,
+ std::vector<OwningDeviceMemory> to_be_released)
+ : result(std::move(result)), to_be_released(std::move(to_be_released)) {}
+ ScopedShapedBuffer result;
+
+ // Leftover buffers for the caller to release. Elements in this list are
+ // donated input memory buffers that are not reused by XLA as outputs.
+ std::vector<OwningDeviceMemory> to_be_released;
+};
+
// A given platform's compiler will produce an Executable -- this is a uniform
// interface that is used for launching compiled programs across platforms.
class Executable {
@@ -63,25 +81,46 @@ class Executable {
// Returns a shaped buffer containing the result of the computation.
virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
// soon as all of the operations are enqueued for launch on the stream.
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) = 0;
+ absl::Span<const ShapedBuffer* const> arguments) = 0;
+
+ // Starts the given program executing on the given stream/executor.
+ //
+ // `arguments` are ShapeTree containing the input parameters. For each element
+ // in the shape tree, if the element holds the ownership of the memory, it is
+ // considered donated and XLA will potentially reuse it as output buffers. For
+ // all donated inputs, XLA is also responsible for freeing them.
+ //
+ // If an input is donated to XLA but is not reused as output, it is returned
+ // as an leftover buffer for the caller to release.
+ virtual StatusOr<ExecutionOutput> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ return Unimplemented(
+ "MaybeOwningDeviceMemory version of overload is not implemented ");
+ }
+
+ virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments) {
+ return Unimplemented(
+ "MaybeOwningDeviceMemory version of overload is not implemented ");
+ }
// Same as ExecuteOnStream(), but runs this executable on multiple
// streams. arguments[i] contains the arguments to the execution on
// run_options[i]->stream() and the returned value is at index i of the
// returned vector.
virtual StatusOr<std::vector<ScopedShapedBuffer>> ExecuteOnStreams(
- tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions>
- run_options,
- tensorflow::gtl::ArraySlice<
- tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- arguments);
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments);
// Populates `hlo_execution_profile` from `executor`. This is implicit in any
// Execute* API call that takes a hlo_execution_profile argument, but must be
@@ -97,7 +136,7 @@ class Executable {
// given ExecutionProfile if non-null.
StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
// Returns the ExecutionProfile from executing on the device. This includes
// the number of cycles taken for the computation or the compilation time.
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 3f1a881372..cb86c98579 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
HloInstruction* start_indices, int64 index_vector_dim) {
@@ -225,7 +224,7 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> slice_sizes, int64 gather_loop_trip_count,
+ absl::Span<const int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
@@ -244,7 +243,7 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
- HloInstruction* accumulator, ArraySlice<int64> offset_dims,
+ HloInstruction* accumulator, absl::Span<const int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 0ce2db907b..4ed91ef187 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const {
}
Status GenericTransferManager::WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) {
TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));
@@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed(
}
Status GenericTransferManager::ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>
+ absl::Span<se::StreamExecutor* const>
/*executors*/) {
return Unimplemented(
"Device reset is not yet supported on this platform (b/30481585)");
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 6c1a21587a..86c8b1c145 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager {
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
- Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+ Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;
int64 GetByteSizeRequirement(const Shape& shape) const override;
protected:
Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d6e9436348..a68b7a1bef 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -132,6 +132,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -192,6 +193,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
@@ -237,6 +239,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
@@ -257,6 +260,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -356,6 +360,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -443,7 +448,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:shape_inference",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
@@ -454,6 +459,7 @@ cc_library(
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
+ ":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -483,6 +489,7 @@ cc_library(
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
+ ":gpu_fusible",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
@@ -531,6 +538,7 @@ cc_library(
srcs = ["fusion_merger.cc"],
hdrs = ["fusion_merger.h"],
deps = [
+ ":gpu_fusible",
":instruction_fusion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -645,9 +653,9 @@ cc_library(
":gpu_constants",
":gpu_copy_insertion",
":gpu_executable",
+ ":gpu_hlo_schedule",
":gpu_hlo_support_checker",
":gpu_layout_assignment",
- ":hlo_schedule",
":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
@@ -668,7 +676,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
- "//tensorflow/compiler/xla/service:convolution_feature_group_converter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@@ -702,6 +709,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
alwayslink = True, # Contains compiler registration
@@ -794,9 +802,9 @@ tf_cc_test(
)
cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
+ name = "gpu_hlo_schedule",
+ srcs = ["gpu_hlo_schedule.cc"],
+ hdrs = ["gpu_hlo_schedule.h"],
deps = [
":stream_assignment",
"//tensorflow/compiler/xla:statusor",
@@ -811,12 +819,12 @@ cc_library(
)
tf_cc_test(
- name = "hlo_schedule_test",
+ name = "gpu_hlo_schedule_test",
srcs = [
- "hlo_schedule_test.cc",
+ "gpu_hlo_schedule_test.cc",
],
deps = [
- ":hlo_schedule",
+ ":gpu_hlo_schedule",
":stream_assignment",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
@@ -875,7 +883,9 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -924,3 +934,26 @@ xla_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "gpu_fusible",
+ srcs = ["gpu_fusible.cc"],
+ hdrs = ["gpu_fusible.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla/service:hlo",
+ ],
+)
+
+tf_cc_test(
+ name = "gpu_fusible_test",
+ srcs = ["gpu_fusible_test.cc"],
+ deps = [
+ ":gpu_fusible",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index 86af83b6b9..528209abc7 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -83,7 +83,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
- "multiple of %x, but was %p",
+ "multiple of 0x%x, but was %p",
kXlaAllocatedBufferAlignBytes, buffer.opaque());
}
// We do manual memory management within BufferAllocations. Be sure not
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index f13eab0dd7..14186b8faa 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index eea31f3de1..05448d863d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -37,8 +37,8 @@ ConvolutionThunk::ConvolutionThunk(
const BufferAllocation::Slice& tuple_result_buffer,
const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo)
+ const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
+ int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
: Thunk(Kind::kConvolution, hlo),
convolution_kind_(convolution_kind),
input_buffer_(input_buffer),
@@ -51,6 +51,7 @@ ConvolutionThunk::ConvolutionThunk(
output_shape_(output_shape),
window_(window),
dim_nums_(dim_nums),
+ feature_group_count_(feature_group_count),
algorithm_(algorithm),
tensor_ops_enabled_(tensor_ops_enabled) {}
@@ -72,8 +73,8 @@ Status ConvolutionThunk::ExecuteOnStream(
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
- stream));
+ filter_data, output_data, scratch, window_, dim_nums_,
+ feature_group_count_, algorithm_config, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index f7952787c1..68d67c40c5 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -59,7 +59,8 @@ class ConvolutionThunk : public Thunk {
const BufferAllocation::Slice& scratch_buffer,
const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
+ const ConvolutionDimensionNumbers& dim_nums,
+ int64 feature_group_count, int64 algorithm,
bool tensor_ops_enabled, const HloInstruction* hlo);
ConvolutionThunk(const ConvolutionThunk&) = delete;
@@ -71,19 +72,6 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- class ScratchAllocator;
-
- Status Convolve(const se::dnn::BatchDescriptor& input_descriptor,
- se::DeviceMemory<float> input_data,
- const se::dnn::FilterDescriptor& filter_descriptor,
- se::DeviceMemory<float> filter_data,
- const se::dnn::BatchDescriptor& output_descriptor,
- se::DeviceMemory<float> output_data,
- const se::dnn::ConvolutionDescriptor& convolution_descriptor,
- const se::dnn::AlgorithmConfig& algorithm_config,
- se::Stream* stream, ScratchAllocator* scratch_allocator,
- se::dnn::ProfileResult* profile_result);
-
const CudnnConvKind convolution_kind_;
const BufferAllocation::Slice input_buffer_;
@@ -98,6 +86,7 @@ class ConvolutionThunk : public Thunk {
const Window window_;
const ConvolutionDimensionNumbers dim_nums_;
+ int64 feature_group_count_;
int64 algorithm_;
bool tensor_ops_enabled_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index dbdf8e7a0e..5c2555148a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -178,7 +178,8 @@ StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ HloInstruction* instr) {
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -192,6 +193,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// concurrently and then run them sequentially.
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
+ // Make sure any previous activity on this executor is done. We don't want to
+ // interfere with programs that are still running on the GPU.
+ if (!stream_exec_->SynchronizeAllActivity()) {
+ return InternalError("Failed to synchronize GPU for autotuning.");
+ }
+
// Create a stream for us to do our work on.
se::Stream stream{stream_exec_};
stream.Init();
@@ -204,9 +211,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
if (allocator_ != nullptr) {
allocator = allocator_;
} else {
- se_allocator.emplace(
- stream_exec_->platform(),
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>({stream_exec_}));
+ se_allocator.emplace(stream_exec_->platform(),
+ absl::Span<se::StreamExecutor* const>({stream_exec_}));
allocator = &*se_allocator;
}
@@ -234,8 +240,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CHECK_EQ(0, left_over_bytes % 2);
constexpr float kBroadcastedConstant = 0.1f;
- Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant),
- Eigen::half(kBroadcastedConstant)};
+ static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant),
+ Eigen::half(kBroadcastedConstant)};
uint32 bits;
static_assert(sizeof(bits) == sizeof(halfs), "");
memcpy(&bits, halfs, sizeof(bits));
@@ -259,7 +265,6 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.ThenMemZero(&filter_buf, filter_buf.size())
.ThenMemZero(&output_buf, output_buf.size());
}
- TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
DeviceMemoryBase* result_buf = [&] {
switch (kind) {
@@ -290,10 +295,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< instr->ToString();
bool launch_ok =
- RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums,
- AlgorithmConfig(alg), &stream, &profile_result)
+ RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape, input_buf,
+ filter_buf, output_buf, &scratch_allocator, window, dnums,
+ feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -379,17 +384,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
+ instr->convolution_dimension_numbers(),
+ instr->feature_group_count(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
+ instr->convolution_dimension_numbers(), instr->feature_group_count(),
+ instr);
} else if (call_target == kCudnnConvBackwardFilterCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
/*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(), instr);
+ instr->window(), instr->convolution_dimension_numbers(),
+ instr->feature_group_count(), instr);
} else {
LOG(FATAL) << "Unknown custom call target for cudnn conv: "
<< instr->ToString();
@@ -423,14 +431,9 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
- HloInstruction* new_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- new_call_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1)},
- instr->custom_call_target()));
- new_call->set_window(instr->window());
- new_call->set_convolution_dimension_numbers(
- instr->convolution_dimension_numbers());
+ HloInstruction* new_call = computation->AddInstruction(
+ instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
+ instr->mutable_operand(1)}));
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f76d273e8c..0cb01161b0 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -51,7 +51,8 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 905b5ee876..9bf721ecd2 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -59,6 +59,11 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward filter.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -218,6 +223,12 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward input.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
+
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
@@ -234,6 +245,23 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
<< "Backward input convolution should reverse all kernel dimensions.";
return no_match_result;
}
+ } else if (reverse_filter->IsConstant()) {
+ // If the filter is a constant, we're willing to pattern-match to a
+ // backwards-input conv, on the theory that
+ //
+ // a) reversing a constant is free, and
+ // b) even if the user specified this filter as reverse(constant), we would
+ // long ago have constant-folded away the reverse.
+ //
+ // If the constant has any other uses, reversing it isn't entirely free,
+ // since we'd now have two constants to keep in memory. But hopefully it's
+ // free enough.
+ //
+ // TODO(jlebar): Should we do this even if the filter is not a constant?
+ // Reversing a non-constant filter is probably cheaper than padding the
+ // input!
+
+ // Nothing to do, just fall through.
} else {
// Possibly 1x1 filter.
for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
@@ -373,22 +401,25 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // Fuse the matched HLOs into a backward convolution instruction.
- //
- // If the reverse is omitted (for 1x1 filters) in the original pattern, we add
- // it back in the fusion instruction so that later passes (such as
- // PadInsertion) can handle such fusion instructions easily.
+ // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
+ // This simplifies things for our caller, and algebraic-simplifier will later
+ // remove any unnecessary reverses.
if (reverse_filter->opcode() != HloOpcode::kReverse) {
- reverse_filter = reverse_filter->parent()->AddInstruction(
+ // Create a double-reverse, which is a nop.
+ HloComputation* c = conv->parent();
+ reverse_filter = c->AddInstruction(
+ HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(
HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
AsInt64Slice(kernel_spatial_dims)));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
+
dnums.set_kernel_input_feature_dimension(
conv->convolution_dimension_numbers().kernel_output_feature_dimension());
dnums.set_kernel_output_feature_dimension(
conv->convolution_dimension_numbers().kernel_input_feature_dimension());
-
return std::make_tuple(true, new_window, dnums);
}
@@ -405,7 +436,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
if (match) {
return CreateCudnnConvBackwardFilter(
conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums);
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums) = MatchBackwardInput(conv);
@@ -415,15 +446,17 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
HloInstruction* rhs = reverse->mutable_operand(0);
- return CreateCudnnConvBackwardInput(
- conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
+ return CreateCudnnConvBackwardInput(conv->shape(),
+ conv->mutable_operand(0), rhs, window,
+ dnums, conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers());
+ conv->convolution_dimension_numbers(),
+ conv->feature_group_count());
}
return nullptr;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 65588b6aaf..46c23db465 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,10 +32,13 @@ namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
-class CudnnConvolutionRewriterTest : public HloTestBase {
+class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
public:
- CudnnConvolutionRewriterTest() {
+ CudnnConvolutionRewriterTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false) {
for (int i = 0; i < 2; ++i) {
WindowDimension* window_dim = default_conv_window_.add_dimensions();
window_dim->set_size(1);
@@ -114,7 +117,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -142,7 +145,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -172,7 +175,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -202,7 +205,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -230,7 +233,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -280,7 +283,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
@@ -325,7 +328,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -357,7 +360,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -410,7 +413,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -457,7 +460,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -510,7 +513,7 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
const HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -562,12 +565,38 @@ TEST_F(CudnnConvolutionRewriterTest,
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
+// Check that we will materialize a reversed version of a constant in order to
+// pattern-match a backwards input convolution.
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
+ Array4D<float> constant_arr(4, 4, 2, 2);
+ constant_arr.FillIota(0);
+ string constant_str =
+ LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+ ParseAndVerifyModule(absl::StrFormat(R"(
+ HloModule test
+
+ ENTRY entry_computation {
+ param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+ constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+ ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+ window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+ dim_labels=bf01_01oi->bf01, feature_group_count=1
+ })",
+ constant_str));
+ EXPECT_TRUE(RunPass(&module()));
+ EXPECT_THAT(
+ module().entry_computation()->root_instruction(),
+ op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
+ op::Reverse(op::Constant())),
+ 0));
+}
+
} // anonymous namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 07b96fbd3f..05125e9d1f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -77,8 +77,9 @@ Status RunCudnnConvolution(
const Shape& output_shape, DeviceMemory<T> input_buf,
DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
- Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ AlgorithmConfig algorithm, Stream* stream,
+ ProfileResult* profile_result /*= nullptr*/) {
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -144,6 +145,7 @@ Status RunCudnnConvolution(
}
ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+ convolution_descriptor.set_group_count(feature_group_count);
for (int dim = 0; dim < num_dimensions; ++dim) {
convolution_descriptor
.set_zero_padding(
@@ -222,14 +224,14 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
+ output_buf, &scratch_allocator, window, dnums, feature_group_count,
+ algorithm, stream, profile_result);
}
Status RunCudnnConvolution(
@@ -237,32 +239,32 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
case F32:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
+ feature_group_count, algorithm, stream, profile_result);
case F64:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
default:
LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 944e4ac686..a1b4fc71d0 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -75,7 +75,7 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
@@ -84,7 +84,7 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 57a3a43a6f..c1aaa4bf04 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
compute_nested_(std::move(compute_nested)) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// The libdevice math functions differentiate between "double" and "float" by
// appending an 'f' to the function's name. libdevice doesn't have f16 math
// functions, so we convert the operands to f32 before calling the function
@@ -119,10 +117,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// llvm intrinsics differentiate between half/float/double functions via
// the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
@@ -144,10 +140,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// Binary math functions transform are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
@@ -290,11 +284,9 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes) {
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 91942785d2..e8b56a39ce 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace gpu {
@@ -38,9 +38,9 @@ namespace gpu {
class GpuElementalIrEmitter : public ElementalIrEmitter {
public:
// A NestedComputer computes an element of the output of the given computation
- // given an ArraySlice of its input elements.
+ // given a Span of its input elements.
using NestedComputer = std::function<StatusOr<llvm::Value*>(
- const HloComputation&, tensorflow::gtl::ArraySlice<llvm::Value*>)>;
+ const HloComputation&, absl::Span<llvm::Value* const>)>;
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
llvm::Module* module, llvm::IRBuilder<>* b,
@@ -96,37 +96,29 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
// Emits IR to call a device function named "callee_name" on the given
// operand. Returns the IR value that represents the return value.
llvm::Value* EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_type,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes);
// Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a libdevice function of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type);
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index 11549cdac5..ca4a605af5 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) {
} // namespace
-FftThunk::FftThunk(FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length,
+FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 4adec7ee54..2be50e08bd 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -62,7 +62,7 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
- FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length,
+ FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 1bd88233e1..30c1f90889 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -225,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
- user->fusion_kind() == HloInstruction::FusionKind::kInput);
+ (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+ LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index b22bb1d39b..7cc869ed9e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
op::Fusion(op::Parameter()));
}
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+ auto module = ParseHloString(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+ // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .ValueOrDie();
+ EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
} // namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 71a02e70df..31a9f9b1be 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -234,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
@@ -325,7 +325,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
// TODO(b/30671675): Implement asynchronous execution mode.
return Unimplemented(
"Asynchronous execution on stream is not yet supported on GPU.");
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 627a05e240..38b0f8f15b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -78,12 +78,12 @@ class GpuExecutable : public Executable {
// match the compute capability passed to this object's constructor.
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
private:
// If `block_host_until_done` is false, execution will not block the host
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
new file mode 100644
index 0000000000..2d31fd5570
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+void AppendParams(const HloInstruction& instr,
+ std::vector<HloInstruction*>* params) {
+ if (instr.opcode() == HloOpcode::kFusion) {
+ params->insert(std::end(*params), std::begin(instr.fused_parameters()),
+ std::end(instr.fused_parameters()));
+ } else {
+ for (HloInstruction* operand : instr.operands()) {
+ params->push_back(operand);
+ }
+ }
+}
+} // namespace
+
+bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
+ const HloInstruction& reduce) {
+ std::vector<HloInstruction*> params;
+ AppendParams(producer, &params);
+ AppendParams(reduce, &params);
+ int64 max_rank = -1;
+ const Layout* max_rank_layout;
+ for (HloInstruction* param : params) {
+ if (ShapeUtil::IsArray(param->shape()) &&
+ ShapeUtil::Rank(param->shape()) > max_rank) {
+ max_rank = ShapeUtil::Rank(param->shape());
+ max_rank_layout = &param->shape().layout();
+ }
+ }
+ return absl::c_all_of(params, [&](HloInstruction* param) {
+ return (!ShapeUtil::IsArray(param->shape())) ||
+ (ShapeUtil::Rank(param->shape()) < max_rank) ||
+ (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+ });
+}
+
+bool IsInputFusibleReduction(const HloInstruction& instr) {
+ if (instr.IsMultiOutputFusion()) {
+ for (const HloInstruction* operand :
+ instr.fused_expression_root()->operands()) {
+ if (IsReductionToVector(*operand)) {
+ CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Multi-output fusion rooted at reduction-to-vector ops must be "
+ "of kind kInput: "
+ << instr.ToString();
+ return true;
+ }
+ }
+ return false;
+ } else if (instr.opcode() == HloOpcode::kFusion) {
+ if (IsReductionToVector(*instr.fused_expression_root())) {
+ CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
+ << instr.ToString();
+ return true;
+ }
+ return false;
+ }
+ return IsReductionToVector(instr);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
new file mode 100644
index 0000000000..f7c24a0d5b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
+// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion.
+
+namespace xla {
+namespace gpu {
+
+// The code emitted for reduce-rooted input fusions (EmitReductionToVector)
+// suffers from poor data locality if the layouts of input parameters differ. In
+// such situtations it is better not to fuse. Only input params with
+// maximum rank are considered. Params with smaller ranks will be broadcasted
+// and have not been observed to cause data locality issues.
+// TODO(b/111977086): Improve reduce emitters to remove this limitation.
+bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
+ const HloInstruction& reduce);
+
+// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr`
+// is either an unfused reduction-to-vector op, an input fusion rooted at a
+// reduction-to-vector op, or a multi-output input fusion with at least one
+// reduction-to-vector op root.
+// Note that reduction ops are lowered in different ways. Reduce input fusions
+// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at
+// reduction-to-vector ops. Other reduction ops are lowered by
+// GpuElementalIrEmitter and fused like elementwise ops.
+bool IsInputFusibleReduction(const HloInstruction& instr);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
new file mode 100644
index 0000000000..d91b7bc61f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+using GpuFusibleTest = HloTestBase;
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+ scalar_add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ })";
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_ElementwiseProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ exp = f32[2,2,2]{2,1,0} exponential(p0)
+ ROOT reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ const HloInstruction* exp =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(exp->opcode(), HloOpcode::kExp);
+ EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*exp, *reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_MixedLayoutProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+ HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(1);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
+ EXPECT_FALSE(
+ LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
+TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduce {
+ p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
+ c0.1 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(p0.1, c0.1), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ copy = f32[128,1024,32,32]{1,3,2,0} copy(p0)
+ ROOT reduce_fusion = f32[1024]{0} fusion(copy), kind=kInput, calls=fused_reduce
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->fused_expression_root()->opcode(), HloOpcode::kReduce);
+ const HloInstruction* copy =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(copy->opcode(), HloOpcode::kCopy);
+ EXPECT_FALSE(LayoutsAreReduceInputFusionFriendly(*copy, *reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_LayoutChangingFusionProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ layout_changing_computation {
+ p0.1 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast)
+ select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=layout_changing_computation
+ ROOT reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce_fusion =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+ HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kCopy);
+ EXPECT_FALSE(
+ LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ broadcasting_computation {
+ p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f32[128]{0} parameter(1)
+ broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
+ ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1 = f16[128]{0} parameter(1)
+ loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[128,1024]{0,1} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kAdd);
+ EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ // Reduction-to-vector lowered by IrEmitterUnnested.
+ ROOT reduce = f32[512]{0} reduce(p1, c0), dimensions={0,2,3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ c0 = f32[] parameter(0)
+ p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1)
+ // Reduction lowered by GpuElementalIrEmitter.
+ ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ ROOT reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = f32[128,512]{1,0} fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1)
+ ROOT reduce = f32[8,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={1,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ ROOT fusion = f32[8,5,1,1]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce.0 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ reduce.1 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ ROOT root = (f32[128,512]{1,0}, f32[128,512]{1,0}) tuple(reduce.0, reduce.1)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[128,512]{1,0}, f32[128,512]{1,0}) fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+ ROOT root = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce.0 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ reduce.1 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ ROOT root = (f32[512,28]{1,0}, f32[512,28]{1,0}) tuple(reduce.0, reduce.1)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[512,28]{1,0}, f32[512,28]{1,0}) fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+ ROOT root = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index 76055ff009..743035a84e 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
#include <unordered_map>
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -184,13 +184,13 @@ void BFSLaunchOrder(const HloComputation* computation,
} // end namespace
-HloSchedule::HloSchedule() {}
+GpuHloSchedule::GpuHloSchedule() {}
/* static */
-StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
+StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
const HloModule& module, const StreamAssignment& stream_assignment,
int64 pointer_size) {
- std::unique_ptr<HloSchedule> schedule(new HloSchedule);
+ std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
// Initialize thunk_launch_order_, the total order of thunk launches.
const HloComputation* entry_computation = module.entry_computation();
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
index 1ce7a48ac8..30a0e7cecd 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
#include <memory>
#include <vector>
@@ -34,11 +34,11 @@ namespace gpu {
// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
// minimize allocations), and also by ThunkSchedule to determine the thunk
// launch order.
-class HloSchedule {
+class GpuHloSchedule {
public:
- // Constructs an HloSchedule for the given module, based on the given stream
- // assignment.
- static StatusOr<std::unique_ptr<HloSchedule>> Build(
+ // Constructs an GpuHloSchedule for the given module, based on the given
+ // stream assignment.
+ static StatusOr<std::unique_ptr<GpuHloSchedule>> Build(
const HloModule& module, const StreamAssignment& stream_assignment,
int64 pointer_size);
@@ -56,7 +56,7 @@ class HloSchedule {
}
private:
- HloSchedule();
+ GpuHloSchedule();
std::vector<const HloInstruction*> thunk_launch_order_;
std::unique_ptr<HloOrdering> hlo_ordering_;
@@ -65,4 +65,4 @@ class HloSchedule {
} // namespace gpu
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index bb147c8d98..0922e44a12 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include <algorithm>
#include <unordered_set>
#include "absl/memory/memory.h"
-#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -31,16 +30,16 @@ limitations under the License.
namespace xla {
namespace gpu {
-class HloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
// Pre-canned shapes.
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
- static std::unique_ptr<HloSchedule> BuildHloSchedule(
+ static std::unique_ptr<GpuHloSchedule> BuildGpuHloSchedule(
const HloModule& module, const StreamAssignment& streams) {
- return HloSchedule::Build(module, streams, /*pointer_size=*/8)
+ return GpuHloSchedule::Build(module, streams, /*pointer_size=*/8)
.ConsumeValueOrDie();
}
@@ -66,7 +65,7 @@ class HloScheduleTest : public HloTestBase {
// Test of a single stream, where data dependencies fully determine the
// execution order.
-TEST_F(HloScheduleTest, SequentialMatMul) {
+TEST_F(GpuHloScheduleTest, SequentialMatMul) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
@@ -86,7 +85,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({dot1, dot2}));
@@ -124,7 +123,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
// Test of a single stream, where data dependencies do not fully determine the
// execution order, but the stream assignment does.
-TEST_F(HloScheduleTest, SequentialAdd) {
+TEST_F(GpuHloScheduleTest, SequentialAdd) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
@@ -148,7 +147,7 @@ TEST_F(HloScheduleTest, SequentialAdd) {
EXPECT_EQ(streams->StreamNumberForHlo(*add1),
streams->StreamNumberForHlo(*add3));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({add1, add2, add3}));
@@ -196,7 +195,7 @@ TEST_F(HloScheduleTest, SequentialAdd) {
}
// Test of two streams.
-TEST_F(HloScheduleTest, ConcurrentMatMul) {
+TEST_F(GpuHloScheduleTest, ConcurrentMatMul) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
@@ -216,7 +215,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
EXPECT_NE(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y});
EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) ||
@@ -252,7 +251,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
}
// Test of multiple streams.
-TEST_F(HloScheduleTest, LatticeMatMul) {
+TEST_F(GpuHloScheduleTest, LatticeMatMul) {
// d00 -- layer 0
// / \
// d10 d11 -- layer 1
@@ -308,7 +307,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
// We don't check the thunk launch order, since there are many valid total
// orders, and it's annoying to express.
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
auto order = schedule->ConsumeHloOrdering();
const HloVec all_params(
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 0e205b9c02..51627402b4 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -35,8 +35,8 @@ using absl::StrAppend;
using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos) {
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
index eee40b0e91..c0edae530c 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace gpu {
@@ -45,8 +45,8 @@ class HloToIrBindings {
alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {}
void EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos);
// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 0bcaaee2b7..4d5d8e99f8 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -41,7 +42,7 @@ bool IsFusible(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kFusion ||
hlo.opcode() == HloOpcode::kGather ||
- hlo.opcode() == HloOpcode::kPad ||
+ hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
hlo.opcode() == HloOpcode::kReshape ||
@@ -221,6 +222,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
+ // Do not fuse into reduce input fusions if the resulting kernel would suffer
+ // from poor data locality (due to unfriendly input layouts).
+ if (IsInputFusibleReduction(*consumer) &&
+ !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
+ return false;
+ }
+
// We can't fuse library calls, so if a user of such an op could become a
// bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
// further rationale.
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index f53dfaee3d..bca775c475 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -171,6 +171,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
op::Reduce(op::Broadcast(op::Constant()), op::Constant()));
}
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ constant.1 = f32[] constant(0)
+ ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
+ })")
+ .ValueOrDie();
+
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ fused_reduce {
+ p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
+ mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
+ c0.1 = f32[] constant(0)
+ ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[]) tuple(fusion)
+ })")
+ .ValueOrDie();
+
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy()));
+}
+
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
auto module = ParseHloString(R"(
HloModule test_module
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index f544bcc919..20d523abe0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -144,10 +144,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(
- const char* call_target, const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+static HloInstruction* CreateCudnnConv(const char* call_target,
+ const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
HloComputation* computation = lhs->parent();
// This call returns a tuple of (conv_result, scratch_memory), where
@@ -165,28 +167,34 @@ static HloInstruction* CreateCudnnConv(
HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
custom_call->set_window(window);
custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
return custom_call;
}
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums);
+ window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums);
+ reverse_filter, window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums);
+ output, window, dnums, feature_group_count);
}
bool IsReductionToVector(const HloInstruction& reduce) {
@@ -216,7 +224,7 @@ bool IsReductionToVector(const HloInstruction& reduce) {
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
for (auto argument : arguments) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index a35e250101..59c65fc268 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -109,15 +109,20 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
//
// The created cudnn call will use the default cudnn algorithm and no scratch
// space.
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
@@ -127,7 +132,7 @@ bool IsReductionToVector(const HloInstruction& reduce);
// Emits call to "vprintf" with given format and arguments.
llvm::Value* EmitPrintf(absl::string_view fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder);
// Emits code to shuffle data between threads of a warp. This has the same
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index bdf6aadde6..ffca5d6549 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -141,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
Status IrEmitter::EmitCallToNestedComputation(
const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output) {
+ absl::Span<llvm::Value* const> operands, llvm::Value* output) {
TF_RET_CHECK(nested_computation.num_parameters() > 0);
llvm::Function*& emitted_function =
computation_to_ir_function_[&nested_computation];
@@ -633,7 +633,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
return EmitTargetElementLoop(
*reduce,
@@ -748,7 +748,7 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
+ absl::Span<llvm::Value* const> parameter_elements) {
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), module_),
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 3673b9f58d..579268f071 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -143,9 +143,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emits a call in IR to the given nested computation with the given operands
// and output. If no IR function has been previously emitted for the
// computation, also emits such a function.
- Status EmitCallToNestedComputation(
- const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output);
+ Status EmitCallToNestedComputation(const HloComputation& nested_computation,
+ absl::Span<llvm::Value* const> operands,
+ llvm::Value* output);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
@@ -199,7 +199,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
StatusOr<llvm::Value*> ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
+ absl::Span<llvm::Value* const> parameter_elements);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c0c8ae181a..389a98facb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -80,7 +81,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -94,7 +94,6 @@ using absl::optional;
using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
-using tensorflow::gtl::ArraySlice;
// If a dimensions is smaller than this, untiled transposition may be more
// efficient.
@@ -176,7 +175,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
+ absl::Span<const BufferAllocation* const> args) {
// Compute the kernel name. The opcode string may contain "-" which cannot be
// in a PTX function name, so sanitize the name before uniquifying it.
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
@@ -490,8 +489,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
@@ -504,8 +503,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
@@ -518,8 +517,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
@@ -556,10 +555,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
- ArraySlice<HloInstruction*> output_instructions =
+ absl::Span<HloInstruction* const> output_instructions =
root->opcode() == HloOpcode::kTuple
? root->operands()
- : ArraySlice<HloInstruction*>(&root, 1);
+ : absl::Span<HloInstruction* const>(&root, 1);
// For multi-output fusion emit an initializer for each tuple element.
// Otherwise it's sufficient to just initialize the single output.
@@ -718,8 +717,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
const HloInstruction* reduce, const IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
for (int i = 0; i != extra_output_gens.size(); ++i) {
const HloInstruction* output = reduce->parent()->FusionInstruction();
@@ -736,12 +734,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
@@ -951,12 +948,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Divide the input matrix into tiles of size KxL. For example, when the
// input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
@@ -1240,12 +1236,11 @@ static std::pair<int64, int64> ComputeTilingSchemeForReduction(
Status IrEmitterUnnested::EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// A naive algorithm is:
// 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
@@ -1593,13 +1588,12 @@ Status IrEmitterUnnested::EmitRowReduction(
// elementwise.
Status IrEmitterUnnested::EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
@@ -1694,7 +1688,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
}
auto input = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(reduce->dimensions());
+ absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
HloComputation* reducer = reduce->to_apply();
// HandleReduce specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires an initializer thunk that
@@ -2570,7 +2564,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// Are all the bytes of this scalar equal to 0? If so, we can create a
// MemzeroThunk.
- ArraySlice<uint8> literal_bytes(
+ absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
@@ -2880,7 +2874,7 @@ int IrEmitterUnnested::ConstructIrArrayForInputs(
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<IrArray>* output_in_reduced_shape_arrays) {
int64 num_outputs = 1;
@@ -2907,7 +2901,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<IrArray>* param_in_reduced_shape_arrays) {
int64 num_params = hlo.operands().size();
@@ -3048,8 +3042,8 @@ void EmitTiledElementalCodeWithBoundsCheck(
// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
// to launch fewer blocks so each transposes many tiles.
LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
- HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ HloInstruction* hlo, absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids) {
// Parameters for the tiling algorithm.
constexpr int64 kTileSize = 32;
constexpr int64 kNumRows = 4;
@@ -3295,7 +3289,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
if (!reduced_dims_021.has_value()) {
reduced_dims_021 = curr_reduced_dims_021;
}
- if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
// There is more than one possible transpose. Instead of picking one
// transpose, we simply give up here.
return false;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 5254419907..084462330e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter {
// This kernel takes as arguments pointers to the given buffer allocations.
llvm::Function* BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
+ absl::Span<const BufferAllocation* const> args);
// Helper for writing extra outputs from inside a reduce kernel.
Status EmitExtraOutputsForReduce(
const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// EmitColumnReduction and EmitRowReduction emit code for column and row
@@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
@@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
Status EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Figures out whether `reduce` is a row or column reduction, and which
@@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter {
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
@@ -195,10 +190,9 @@ class IrEmitterUnnested : public IrEmitter {
// Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
// returns the launch dimensions for the kernel. This is a helper to support
// the implementation of CheckAndEmitHloWithTile021.
- LaunchDimensions EmitHlo021Tile(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
+ LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
+ absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids);
// Generates the IrArray for each output of hlo and returns the number of
// outputs.
int ConstructIrArrayForOutputs(const HloInstruction& hlo,
@@ -214,7 +208,7 @@ class IrEmitterUnnested : public IrEmitter {
int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
// For each input of the `hlo` instruction, checks its value in
@@ -226,7 +220,7 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index 3259eaa2a2..e09b8fbd3b 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -27,10 +27,10 @@ limitations under the License.
namespace xla {
namespace gpu {
-KernelThunk::KernelThunk(
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction,
- int unroll_factor)
+KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
+ const string& kernel_name,
+ const HloInstruction* hlo_instruction,
+ int unroll_factor)
: Thunk(Kind::kKernel, hlo_instruction),
args_(args.begin(), args.end()),
kernel_name_(kernel_name),
@@ -41,11 +41,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
tensorflow::mutex_lock lock(mutex_);
if (!loader_spec_) {
loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
- absl::string_view ptx = executable.ptx();
- // Convert absl::string_view to se::port::StringPiece because
- // StreamExecutor uses the latter.
- loader_spec_->AddCudaPtxInMemory(
- se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
+ loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_);
if (!executable.cubin().empty()) {
loader_spec_->AddCudaCubinInMemory(
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index d751de50ad..f63db5c369 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -47,7 +47,7 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
- KernelThunk(tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
+ KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction,
int unroll_factor);
KernelThunk(const KernelThunk&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 7a43f0be54..c21f76f6eb 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -86,67 +87,13 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
get_element_shape(element_instr_1), get_element_shape(element_instr_2));
}
-namespace {
-bool IsInputFusibleReduction(HloInstruction* instr) {
- if (instr->IsMultiOutputFusion()) {
- for (const HloInstruction* operand :
- instr->fused_expression_root()->operands()) {
- if (operand->opcode() == HloOpcode::kReduce) {
- CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput)
- << " Reduce multi-output fusion " << instr->ToString()
- << " must be an input fusion.";
- return true;
- }
- }
- return false;
- } else if (instr->opcode() == HloOpcode::kFusion) {
- // The loop emitter can handle to-vector reduce fusions. Such reduce
- // fusions have the fusion kind kLoop rather than kInput. We do not fuse
- // to-vector reduce fusions, because the resulting fusions may no longer be
- // supported by loop emitter.
- return IsReductionToVector(*instr->fused_expression_root());
- } else {
- return IsReductionToVector(*instr);
- }
-}
-
-// The code emitted for reduction suffers from poor data locality if the layouts
-// of input parameters differ. In such situtations it is beneficial not to fuse.
-// We consider input params with maximum rank only. Params with smaller ranks
-// will be broadcasted and have not been observed to cause data locality issues.
-// TODO(b/111977086): Improve reduce emitters to remove this limitation.
-bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
- std::vector<HloInstruction*> params;
- if (instr->opcode() == HloOpcode::kFusion) {
- params = instr->fused_parameters();
- } else {
- for (HloInstruction* operand : instr->operands()) {
- params.push_back(operand);
- }
- }
- int64 max_rank = 0;
- const Layout* max_rank_layout;
- for (HloInstruction* param : params) {
- if (ShapeUtil::Rank(param->shape()) > max_rank) {
- max_rank = ShapeUtil::Rank(param->shape());
- max_rank_layout = &param->shape().layout();
- }
- }
- return absl::c_all_of(params, [&](HloInstruction* param) {
- return (ShapeUtil::Rank(param->shape()) < max_rank) ||
- (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
- });
-}
-
-} // namespace
-
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions. Elementwise instructions can be fused
// with any other instruction.
// TODO(b/112957171): This should use the same isFusible logic as
// instruction_fusion.
return instr->IsFusible() &&
- (IsInputFusibleReduction(instr) ||
+ (IsInputFusibleReduction(*instr) ||
(instr->opcode() == HloOpcode::kFusion &&
instr->fusion_kind() == HloInstruction::FusionKind::kLoop) ||
instr->IsElementwise());
@@ -219,7 +166,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
VLOG(3) << consumer->name() << " has no users.";
continue;
}
- if (!IsInputFusibleReduction(consumer)) {
+ if (!IsInputFusibleReduction(*consumer)) {
VLOG(3) << consumer->name() << " is not an input-fusible reduction.";
continue;
}
@@ -244,7 +191,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
VLOG(3) << producer->name() << " has an incompatible shape.";
continue;
}
- if (!ReduceFriendlyInputLayouts(producer)) {
+ if (!LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
VLOG(3) << producer->name() << " has inputs with mixed layouts.";
continue;
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 695feadb11..f6325b3368 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -45,9 +44,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
@@ -208,9 +207,11 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
- // TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
- pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
+ // CudnnConvolutionRewriter may add instructions of the form
+ // reverse(constant), which it expects will be simplified by constant
+ // folding.
+ pipeline.AddPass<HloConstantFolding>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -565,8 +566,8 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// must also be used to determine the thunk launch schedule.
std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloSchedule> hlo_schedule,
- HloSchedule::Build(*module, *stream_assignment, pointer_size_));
+ std::unique_ptr<GpuHloSchedule> hlo_schedule,
+ GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 08ef6ef56c..8e97774750 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -21,12 +21,12 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index 79f7d31816..fa84d77223 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-using tensorflow::gtl::ArraySlice;
// We want the input/output feature counts of an f16 conv to be factors of 8,
// because without this cudnn can't use tensor cores on the conv.
@@ -42,7 +41,7 @@ static constexpr double kMaxBytesTouchedIncrease = 1.2;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
-static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
int64 new_dim_to_pad_size =
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 104af48c82..5c92b0dcb8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,12 +29,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class PadForTensorCoresTest : public HloVerifiedTestBase {
- public:
- PadForTensorCoresTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class PadForTensorCoresTest : public HloVerifiedTestBase {};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 98cc21ccac..9d85d746d8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -166,9 +166,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
Shape old_conv_shape = conv->shape().tuple_shapes(0);
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
- new_conv_window,
- conv->convolution_dimension_numbers());
+ auto new_conv = CreateCudnnConvForward(
+ old_conv_shape, new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -247,7 +247,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -312,7 +312,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index ca57cacb98..8154d75d23 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter(
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
: LoopEmitter(target_element_generator, target_arrays, b),
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index cc7da2e73b..f32ea1ce4c 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -47,11 +47,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
//
// This is used in multi-output fusion. target_element_generator should
// produce a struct with N elements, one for each of target_arrays.
- ParallelLoopEmitter(
- const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
- int unroll_factor = 1);
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
+ const LaunchDimensions& launch_dimensions,
+ llvm::IRBuilder<>* b, int unroll_factor = 1);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index 05b305ea4c..08ff52211a 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
@@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
input_layout.push_back(dnums.input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid input layout: ",
- DataLayoutString(input));
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ DataLayoutString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> filter_layout;
@@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
filter_layout.push_back(dnums.kernel_input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid filter layout: ",
- FilterLayoutString(filter));
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ FilterLayoutString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> output_layout;
@@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
output_layout.push_back(dnums.output_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid output layout: ",
- DataLayoutString(output));
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ DataLayoutString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
@@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(input, nhwc_input)) {
input_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid input layout: ",
- input.ShortDebugString());
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
FilterLayout filter_layout;
@@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(filter, nhwc_filter)) {
filter_layout = FilterLayout::kOutputYXInput;
} else {
- return tensorflow::errors::Internal("Invalid filter layout: ",
- filter.ShortDebugString());
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
DataLayout output_layout;
@@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(output, nhwc_output)) {
output_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid output layout: ",
- output.ShortDebugString());
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(input_layout, filter_layout, output_layout);
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 2d5735d6c4..dcdbf2cf3c 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -34,8 +34,7 @@ namespace gpu {
// issue (b/31336476).
class TupleThunk : public Thunk {
public:
- TupleThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice>
- tuple_element_buffers,
+ TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kTuple, hlo_instruction),
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index 1fea544730..e345804537 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index da94ab5346..54abe3345d 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -39,15 +39,17 @@ namespace {
using ::testing::UnorderedElementsAre;
-class HloAliasAnalysisTest : public HloTestBase {
+class HloAliasAnalysisTest : public HloVerifiedTestBase {
protected:
- HloAliasAnalysisTest() : module_(CreateNewModule()) {}
+ HloAliasAnalysisTest() : HloVerifiedTestBase() {
+ module_ = CreateNewModule();
+ }
// Run alias analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
HloAliasAnalysis& RunAnalysis() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis");
- analysis_ = HloAliasAnalysis::Run(module_.get(),
+ analysis_ = HloAliasAnalysis::Run(module_,
/*fusion_can_share_buffer=*/nullptr)
.ConsumeValueOrDie();
return *analysis_;
@@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase {
// never occurs, but HLO graphs with interference can be explicitly
// constructed.
bool AnyValuesInSameBufferInterfere() {
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
for (const HloBuffer& buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
@@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return false;
}
- std::unique_ptr<HloModule> module_;
+ HloModule* module_;
std::unique_ptr<HloAliasAnalysis> analysis_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
@@ -461,7 +463,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
module_->AddEntryComputation(builder.Build());
FlattenCallGraph flattener;
- TF_ASSERT_OK(flattener.Run(module_.get()).status());
+ TF_ASSERT_OK(flattener.Run(module_).status());
const HloAliasAnalysis& analysis = RunAnalysis();
@@ -835,7 +837,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) {
const HloAliasAnalysis& analysis = RunAnalysis();
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
@@ -877,7 +879,7 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
{
// Dependency ordering should interfere because the negate and while are
// unordered.
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
@@ -888,13 +890,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
sequence[condition] = {cond_param, cond_root};
{
sequence[entry] = {init, xla_while, negate, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(module_, sequence);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
{
sequence[entry] = {init, negate, xla_while, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(module_, sequence);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h
index 4873463b2e..a88c87e46c 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.h
+++ b/tensorflow/compiler/xla/service/hlo_buffer.h
@@ -84,7 +84,7 @@ class HloBuffer {
return a->id() == b->id();
}
- HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
+ HloBuffer(Id id, absl::Span<const HloValue* const> values)
: id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c2d0673f49..fe7f2be888 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -558,7 +558,7 @@ HloComputation::CreateFromProto(
}
void HloComputation::FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction) {
CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
HloInstruction* root = instructions_to_fuse.front();
@@ -577,7 +577,7 @@ void HloComputation::FuseInstructionsInto(
}
HloInstruction* HloComputation::CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind) {
HloInstruction* root = instructions_to_fuse.front();
HloInstruction* fusion_instruction = AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 59016624f7..fe2d3bbbe5 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@@ -237,7 +237,7 @@ class HloComputation {
// removed if they have no uses after fusion (this is necessarily true for at
// least the root).
HloInstruction* CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind);
// Create a deep copy of the given instruction and return the instruction
@@ -385,7 +385,7 @@ class HloComputation {
//
// Pre-condition: fusion_instruction's opcode is kFusion.
void FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction);
// Internal helper for recursive copying of an instruction. Creates and
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 2ed645c3ae..8a45939c61 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -71,7 +71,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Broadcasts dramatically increase the size of constants, which is often
// detrimental to performance and memory capacity, so do not fold
// broadcasts.
- if (instruction->opcode() == HloOpcode::kBroadcast) {
+ if (instruction->opcode() == HloOpcode::kBroadcast ||
+ instruction->opcode() == HloOpcode::kIota) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 7cd1481a8a..07cd1efc12 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -105,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
- tensorflow::gtl::ArraySlice<int64> dimensions;
- tensorflow::gtl::ArraySlice<int64> concat_sizes;
+ absl::Span<const int64> dimensions;
+ absl::Span<const int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
@@ -196,7 +196,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
bool matched = true;
root->literal().EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ [&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
});
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 0e12a1ee03..939b5114c3 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -274,15 +274,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
}
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
- auto arg = reduce->operand(0);
HloComputation* function = reduce->to_apply();
// Compute the cost of the user function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
ProcessSubcomputation(function));
// Compute the cost of all elements for this Reduce operation.
- int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
- ShapeUtil::ElementsIn(reduce->shape());
+ // This counts the number of times the reduction function is applied, so it
+ // does not need to be multiplied by the number of input tensors - that's
+ // already "priced in" by the sub-computation doing more work.
+ auto arg = reduce->operand(0);
+ auto output_shape = ShapeUtil::IsArray(reduce->shape())
+ ? reduce->shape()
+ : reduce->shape().tuple_shapes(0);
+ int64 reduction_count =
+ ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
current_properties_[property.first] = property.second * reduction_count;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index c6a2007904..9bb3f12ee2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 131846794d..19ffb465c0 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -24,7 +24,6 @@ limitations under the License.
namespace xla {
using absl::StrCat;
-using tensorflow::gtl::ArraySlice;
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs) {
@@ -50,9 +49,9 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
- ArraySlice<int64> start_indices,
- ArraySlice<int64> limit_indices,
- ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
operand->shape(), start_indices,
@@ -74,7 +73,7 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
- ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(
Shape transpose_shape,
@@ -91,15 +90,15 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
}
StatusOr<HloInstruction*> MakeReshapeHlo(
- ArraySlice<int64> result_shape_dim_bounds, HloInstruction* operand) {
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_dim_bounds);
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> MakeDynamicSliceHlo(HloInstruction* operand,
- HloInstruction* start_indices,
- ArraySlice<int64> slice_sizes) {
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(
+ HloInstruction* operand, HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, start_indices->parent());
TF_ASSIGN_OR_RETURN(
@@ -125,8 +124,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
}
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand, ArraySlice<int64> broadcast_dimensions,
- ArraySlice<int64> result_shape_bounds) {
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds) {
HloComputation* computation = operand->parent();
Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_bounds);
@@ -146,8 +145,8 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
}
-StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
- int64 dimension) {
+StatusOr<HloInstruction*> MakeConcatHlo(
+ absl::Span<HloInstruction* const> operands, int64 dimension) {
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
@@ -176,9 +175,8 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation) {
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation) {
CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
HloComputation* computation = operands.front()->parent();
std::vector<const Shape*> operand_shapes;
@@ -235,7 +233,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
}
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, ArraySlice<int64> expanded_dims) {
+ HloInstruction* operand, absl::Span<const int64> expanded_dims) {
CHECK_GT(operand->shape().dimensions_size(), 0);
CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
@@ -251,8 +249,8 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
- ArraySlice<int64> dims_to_elide) {
+StatusOr<HloInstruction*> ElideDegenerateDims(
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
@@ -277,7 +275,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
}
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
CHECK(absl::c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
@@ -327,7 +325,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
@@ -336,7 +334,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
}
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- ArraySlice<const Shape*> domain, const Shape& range,
+ absl::Span<const Shape* const> domain, const Shape& range,
absl::string_view name) {
HloComputation::Builder b{string(name)};
int64 param_idx = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 1bc6d09b45..a1c4b374d1 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -40,10 +40,10 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeSliceHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
@@ -53,8 +53,8 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeTransposeHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
@@ -62,15 +62,14 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
HloInstruction* operand);
StatusOr<HloInstruction*> MakeReshapeHlo(
- tensorflow::gtl::ArraySlice<int64> result_shape_dim_bounds,
- HloInstruction* operand);
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
@@ -82,9 +81,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions,
- tensorflow::gtl::ArraySlice<int64> result_shape_bounds);
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds);
// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
@@ -95,7 +93,7 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
+ absl::Span<HloInstruction* const> operands, int64 dimension);
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
@@ -104,9 +102,8 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
@@ -138,7 +135,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
// For instance if `operand` has shape f32[200,9,7] and expanded_dims is
// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
+ HloInstruction* operand, absl::Span<const int64> expanded_dims);
// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide` from `operand`. Every dimension in
@@ -148,7 +145,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide);
// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert` into `operand`. The dimensions in
@@ -158,7 +155,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert);
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
@@ -171,12 +168,12 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
// broadcast instruction is emitted into `computation`.
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range,
+ absl::Span<const Shape* const> domain, const Shape& range,
absl::string_view name);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index a8de285d16..eb6affadc8 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -19,18 +19,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
-class HloCreationUtilsTest : public HloTestBase {
+class HloCreationUtilsTest : public HloVerifiedTestBase {
protected:
- std::unique_ptr<HloModule> CreateModuleWithProgramShape(
- PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims, HloInstruction** param,
+ HloModule* CreateModuleWithProgramShape(
+ PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
+ absl::Span<const int64> output_shape_dims, HloInstruction** param,
HloComputation** entry_computation) {
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
Shape output_shape =
@@ -48,10 +47,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
CollapseFirstNDims(param, 1));
@@ -68,7 +67,7 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
&entry_computation);
@@ -93,10 +92,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{1, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
PrependDegenerateDims(param, 1));
@@ -114,7 +113,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, &param,
&entry_computation);
@@ -135,10 +134,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{1, 1},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
PrependDegenerateDims(param, 2));
@@ -155,7 +154,7 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, &param,
&entry_computation);
@@ -177,10 +176,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{6},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zero_padded_param,
@@ -198,10 +197,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
@@ -219,10 +218,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- F32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(F32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 3376d170e6..6a63681996 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -46,8 +46,7 @@ namespace {
//
// In this case, we should be able to reuse p0 and output, although p0 has
// multiple uses.
-bool MultiDynamicSliceUseShareSameIndices(
- tensorflow::gtl::ArraySlice<HloUse> uses) {
+bool MultiDynamicSliceUseShareSameIndices(absl::Span<const HloUse> uses) {
if (uses.empty()) {
return false;
}
@@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const {
bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK(ssa_form_);
VLOG(4) << "Phi(" << instruction->name() << ")";
VLOG(5) << "instruction value set = "
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index a1678d4943..e62c1c2ac8 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -202,7 +202,7 @@ class HloDataflowAnalysis {
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ absl::Span<const InstructionValueSet* const> inputs);
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index c8e0a9e289..974ab94467 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -29,11 +29,6 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
- public:
- HloDomainTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index b9244b8e9e..72006e17e7 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -151,7 +151,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
}
TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
- if (!HasOperandType(hlo, eliminate_type_)) {
+ bool nullary = hlo->operands().empty();
+ bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
+ bool should_eliminate_type = (nullary && wrong_element_type) ||
+ HasOperandType(hlo, eliminate_type_);
+ if (!should_eliminate_type) {
// If this CHECK fires, then this was an instruction that does not take
// the elimination type as an operand but it does return it. This pass
// does not have a feature to change the output type in that case, so
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 71f91fde93..441dcad000 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,7 +53,6 @@ namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
@@ -97,10 +96,11 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<OperandT>(multi_index),
- rhs_literal.Get<OperandT>(multi_index));
- }));
+ TF_RETURN_IF_ERROR(
+ result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<OperandT>(multi_index),
+ rhs_literal.Get<OperandT>(multi_index));
+ }));
return std::move(result);
}
@@ -127,10 +127,11 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<complex64>(multi_index),
- rhs_literal.Get<complex64>(multi_index));
- }));
+ TF_RETURN_IF_ERROR(
+ result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<complex64>(multi_index),
+ rhs_literal.Get<complex64>(multi_index));
+ }));
return std::move(result);
}
@@ -194,7 +195,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloModule& module, ArraySlice<LiteralPtr> arg_literals) {
+ const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
evaluated_.clear();
@@ -211,7 +212,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloComputation& computation, ArraySlice<LiteralPtr> arg_literals) {
+ const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
XLA_VLOG_LINES(
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
@@ -228,7 +230,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
+ HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
evaluated_.clear();
@@ -390,7 +392,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
}
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
- ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
@@ -588,7 +590,7 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
- int64 output_rank, ArraySlice<int64> slice_sizes,
+ int64 output_rank, absl::Span<const int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
@@ -660,12 +662,13 @@ class OutputBatchIndexToInputIndex {
// index_vector_index_ and index_vector on every invocation, we reuse the
// same storage for all invocations.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
@@ -674,7 +677,7 @@ class OutputBatchIndexToInputIndex {
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
void PropagateOutputIndexGatherDimsToIndexVectorIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
if (!output_dim_is_batch_dims_[i]) {
@@ -729,7 +732,7 @@ class OutputBatchIndexToInputIndex {
// The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
@@ -778,10 +781,11 @@ class OutputOffsetIndexToInputIndex {
// gather input index on every invocation we reuse the same storage for the
// result (input_index_), mutating it in place.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexWindowDimsToInputIndex(output_index);
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding output dimension index,
@@ -794,7 +798,7 @@ class OutputOffsetIndexToInputIndex {
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
void PropagateOutputIndexWindowDimsToInputIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_output_index_[i] != -1) {
input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
@@ -810,7 +814,7 @@ class OutputOffsetIndexToInputIndex {
// PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
};
@@ -872,11 +876,11 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
const Shape& operand_shape = operand.shape();
auto gather_inner_loop_body =
- [&](ArraySlice<int64> output_window_index,
- ArraySlice<int64> input_gather_index,
- ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+ [&](absl::Span<const int64> output_window_index,
+ absl::Span<const int64> input_gather_index,
+ absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_window_index,
+ absl::Span<const int64> input_window_index,
output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
@@ -909,8 +913,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
};
auto gather_outer_loop_body =
- [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(ArraySlice<int64> input_gather_index,
+ [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
+ TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
shape, offset_indices_iteration_space,
@@ -1170,12 +1174,11 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_values.push_back(key_value.second);
}
auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<KeyType>(result_keys));
+ result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
auto result_values_literal =
absl::make_unique<Literal>(values_literal.shape());
result_values_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ValueType>(result_values));
+ absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
@@ -1262,7 +1265,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
if (sort_dim != rank - 1) {
return Unimplemented(
- "Trying to support along dimension %d, which is not the last "
+ "Trying to sort along dimension %d, which is not the last "
"dimension",
sort_dim);
}
@@ -1281,6 +1284,22 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
}
}
+Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
+ if (!ShapeUtil::IsTuple(reduce->shape())) {
+ return DefaultAction(reduce);
+ } else {
+ auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
+ for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
+ if (tuple_shape.element_type() != first_element_type) {
+ return Unimplemented(
+ "Reduce with several outputs that have mixed element types is "
+ "unsupported");
+ }
+ }
+ return reduce->Visit(typed_visitors_.at(first_element_type).get());
+ }
+}
+
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return ShapeUtil::ValidateShape(hlo->shape());
@@ -1295,26 +1314,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
// Explicit instantiation of templatized Evaluate* methods.
//
template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloModule& module,
- ArraySlice<const Literal*> arg_literals);
+HloEvaluator::Evaluate<const Literal*>(
+ const HloModule& module, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module, ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ const HloModule& module,
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloComputation& computation,
- ArraySlice<const Literal*> arg_literals);
+template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
+ const Literal*>(const HloComputation& computation,
+ absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
const HloComputation& computation,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(HloInstruction* instruction,
- ArraySlice<const Literal*> arg_literals);
+HloEvaluator::Evaluate<const Literal*>(
+ HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
template StatusOr<std::unique_ptr<Literal>>
HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
HloInstruction* instruction,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const std::unique_ptr<Literal>> arg_literals);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 0ea7089552..c2d49e56ac 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
@@ -51,8 +51,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -75,7 +74,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
@@ -87,8 +86,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// type.
template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
@@ -185,6 +183,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
@@ -227,8 +227,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
}
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index c3af15c6a8..7e490d7f32 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -60,7 +60,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
}
std::unique_ptr<Literal> Evaluate(
- tensorflow::gtl::ArraySlice<const Literal*> arg_literals = {}) {
+ absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -344,7 +344,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
result->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ [&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
});
@@ -935,7 +935,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
- Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
+ Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
@@ -1012,7 +1012,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
- Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
+ Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
@@ -1219,12 +1219,7 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
- public:
- HloEvaluatorPreciseReduceTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index f682e69ee9..cb27e13e99 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -97,7 +97,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
double GetAsDouble(const Literal& literal,
- tensorflow::gtl::ArraySlice<int64> input_index) {
+ absl::Span<const int64> input_index) {
return static_cast<double>(literal.Get<NativeT>(input_index));
}
@@ -109,7 +109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
double GetAsDouble(const Literal& literal,
- tensorflow::gtl::ArraySlice<int64> input_index) {
+ absl::Span<const int64> input_index) {
LOG(FATAL) << "Trying to get complex literal as double: "
<< literal.ToString();
}
@@ -980,8 +980,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1048,8 +1048,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](
- tensorflow::gtl::ArraySlice<int64> out_index) {
+ rhs_literal_data](absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1130,7 +1129,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
}
cnt : {}
- } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+ } while (IndexUtil::BumpIndices(window_shape,
+ absl::MakeSpan(rhs_spatial_index)));
return static_cast<ReturnT>(result_val);
};
@@ -1198,20 +1198,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
if (i != lhs_contracting_dimension &&
- !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
+ !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) {
result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
for (int64 i = 0; i < rhs_rank; i++) {
if (i != rhs_contracting_dimension &&
- !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
+ !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) {
result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
auto result = absl::make_unique<Literal>(dot->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> result_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1260,9 +1260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
auto result = absl::make_unique<Literal>(pad->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
- return scalar;
- }));
+ [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
parent_->GetEvaluatedLiteralFor(pad->operand(0));
@@ -1275,7 +1273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// corresponding index of the resulting padded literal.
const PaddingConfig& pad_config = pad->padding_config();
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](absl::Span<const int64> input_index) {
for (auto i = 0; i < input_index.size(); ++i) {
// Interior padding occurs logically before edge padding, so in the case
// of negative edge padding elements are removed from the
@@ -1426,8 +1424,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
std::vector<std::unique_ptr<Literal>> arg_literals;
arg_literals.reserve(operands.size());
@@ -1538,8 +1536,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return SafeLess<ReturnT>(a, b);
});
auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
return result_literal;
};
@@ -1577,20 +1574,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleSort<ReturnT>(sort);
}
- Status HandleReduce(HloInstruction* reduce) override {
- // TODO(b/112040122): Support variadic reduce.
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return Unimplemented("Variadic reduce is not supported in the Evaluator");
- }
- auto arg = reduce->operand(0);
- auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ Status HandleReduce(HloInstruction* hlo) override {
+ HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
+ int64 num_args = reduce->inputs().size();
+ bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
- ShapeUtil::Rank(arg->shape()) - dimensions.size());
+
+ absl::InlinedVector<const Shape*, 1> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
- {&arg->shape(), &init_value->shape()},
+ operand_shapes,
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
@@ -1598,14 +1595,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
- const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
- VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
+ absl::InlinedVector<const Literal*, 1> arg_literals(num_args);
+ absl::InlinedVector<const Literal*, 1> init_literals(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]);
+ VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString();
+ init_literals[i] =
+ &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]);
+ VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString();
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape()));
+ }
- const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+ // All args and results have the same dimensions, so pick an arbitrary one.
+ const Shape& arg_shape = arg_literals[0]->shape();
+ const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape())
+ ? reduce->shape().tuple_shapes(0)
+ : reduce->shape();
+ const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions());
std::vector<int64> arg_dim_steps(arg_dimensions.size());
std::vector<int64> arg_dim_counts(arg_dimensions.size());
for (const int64 dim : dimensions) {
@@ -1623,63 +1629,109 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = absl::make_unique<Literal>(reduce->shape());
+ absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ results[i] = absl::make_unique<Literal>(result_shape);
+ }
+
Status eval_status;
- // For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- ReturnT result_val = init_scalar;
- if (!eval_status.ok()) {
- return result_val;
- }
+ // For each resulting dimension, calculate and assign computed values.
+ // This is really wasteful when num_args > 1, since we re-run the
+ // reduction num_args time. The alternative is to teach Populate() about
+ // tuples, which we should probably do.
+ absl::InlinedVector<ReturnT, 1> init_scalars(num_args);
+ for (int i = 0; i < num_args; ++i) {
+ init_scalars[i] = init_literals[i]->Get<ReturnT>({});
+ }
- std::vector<int64> base(arg_dimensions.size());
- for (int64 i = 0; i < multi_index.size(); ++i) {
- base[result_to_arg_index[i]] = multi_index[i];
- }
+ for (int64 input = 0; input < num_args; ++input) {
+ TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+ [&](absl::Span<const int64> multi_index) {
+ if (!eval_status.ok()) {
+ return init_scalars[input];
+ }
+ absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(),
+ init_scalars.end());
+ std::vector<int64> base(arg_dimensions.size());
+ for (int64 i = 0; i < multi_index.size(); ++i) {
+ base[result_to_arg_index[i]] = multi_index[i];
+ }
+
+ // When the reduction is addition of floats, accumulate in a double
+ // for better precision. Also, avoid creating Literals for the
+ // intermediate results; it's much faster.
+ if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) &&
+ IsScalarAdd(function)) {
+ CHECK_EQ(num_args, 1);
+ double computed_result = 0;
+ auto func = [&](absl::Span<const int64> input_index) {
+ computed_result +=
+ GetAsDouble<ReturnT>(*arg_literals[0], input_index);
+ return true;
+ };
+ ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base,
+ arg_dim_counts, arg_dim_steps, func);
+ return static_cast<ReturnT>(computed_result);
+ }
+ auto func =
+ [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
+ absl::InlinedVector<ReturnT, 1> arg_values(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index);
+ }
- // When the reduction is addition of floats, accumulate in a double
- // for better precision. Also, avoid creating Literals for the
- // intermediate results; it's much faster.
- if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
- IsScalarAdd(function)) {
- double computed_result = 0;
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
- computed_result += GetAsDouble<ReturnT>(arg_literal, input_index);
+ // Evaluate computation with specified literal operands.
+ absl::InlinedVector<std::unique_ptr<Literal>, 1>
+ embedded_operands;
+ for (ReturnT value : result_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ for (ReturnT value : arg_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ absl::InlinedVector<Literal*, 1> embedded_operands_ptrs(
+ embedded_operands.size());
+ std::transform(embedded_operands.begin(), embedded_operands.end(),
+ embedded_operands_ptrs.begin(),
+ [](const std::unique_ptr<Literal>& ptr) {
+ return ptr.get();
+ });
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ embedded_evaluator.Evaluate<const Literal*>(
+ *function, embedded_operands_ptrs));
+ // Clear visit states so that we can use the evaluator again on
+ // the same computation.
+ embedded_evaluator.ResetVisitStates();
+ // Assign computed result to result_val.
+ if (!has_tuple_output) {
+ result_values[0] = computed_result->Get<ReturnT>({});
+ } else {
+ for (int64 i = 0; i < num_args; ++i) {
+ result_values[i] = computed_result->Get<ReturnT>(
+ /*multi_index=*/{}, /*shape_index=*/{i});
+ }
+ }
return true;
};
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return static_cast<ReturnT>(computed_result);
- }
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
- -> StatusOr<bool> {
- auto curr_val = arg_literal.Get<ReturnT>(input_index);
-
- // Evaluate computation with specified literal operands.
- auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
- auto result_val_literal =
- LiteralUtil::CreateR0<ReturnT>(result_val);
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
- embedded_evaluator.Evaluate<const Literal*>(
- *function, {result_val_literal.get(),
- curr_val_literal.get()}));
- // Clear visit states so that we can use the evaluator again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
- // Assign computed result to result_val.
- result_val = computed_result->Get<ReturnT>({});
- return true;
- };
- // Computes one element of the result, reducing all dimensions that
- // contribute to that element.
- eval_status = ShapeUtil::ForEachIndexWithStatus(
- arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
- return result_val;
- }));
-
- parent_->evaluated_[reduce] = std::move(result);
+ // Computes one element of the result, reducing all dimensions that
+ // contribute to that element.
+ eval_status = ShapeUtil::ForEachIndexWithStatus(
+ arg_shape, base, arg_dim_counts, arg_dim_steps, func);
+ return result_values[input];
+ }));
+ }
+ if (!has_tuple_output) {
+ parent_->evaluated_[reduce] = std::move(results[0]);
+ } else {
+ auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+ for (int64 i = 0; i < num_args; ++i) {
+ TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+ }
+ parent_->evaluated_[reduce] = std::move(tuple_result);
+ }
return eval_status;
}
@@ -1711,9 +1763,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize result array with the init value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
- return init_scalar;
- }));
+ [&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
for (const auto& window_dimension : window.dimensions()) {
@@ -1799,7 +1849,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
embedded_evaluator.ResetVisitStates();
}
});
- } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+ } while (
+ IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index)));
parent_->evaluated_[select_and_scatter] = std::move(result);
return Status::OK();
@@ -1845,8 +1896,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
auto result = absl::make_unique<Literal>(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1991,13 +2042,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// index_vector_index_ and index_vector on every invocation, we reuse the
// same storage for all invocations.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
@@ -2006,7 +2057,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// update the dim_numbers.index_vector_dim() dimension -- that's the
// dimension we iterate over in FetchIndexVector.
void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = update_index.size(); i < e; i++) {
if (!update_dim_is_scatter_dims_[i]) {
@@ -2061,7 +2112,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// The index vector fetched from scatter_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
@@ -2114,11 +2165,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// scatter input index on every invocation we reuse the same storage for the
// result (input_index_), mutating it in place.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexWindowDimsToInputIndex(update_index);
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding update dimension index,
@@ -2131,7 +2182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Propagates window dimensions from the update index to input_index_ by
// mutating input_index_ in place.
void PropagateUpdateIndexWindowDimsToInputIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_update_index_[i] != -1) {
input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
@@ -2147,7 +2198,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// PropagateUpdateIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_update_index_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
};
@@ -2190,12 +2241,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::unique_ptr<Literal> result = operand.CloneToUnique();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index,
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
- tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_window_index,
+ absl::Span<const int64> input_scatter_index,
+ absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_window_index,
+ absl::Span<const int64> input_window_index,
update_window_index_to_input_index(update_window_index));
for (int i = 0, e = update_index.size(); i < e; i++) {
update_index[i] = update_scatter_index[i] + update_window_index[i];
@@ -2244,14 +2294,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
};
auto scatter_outer_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
+ absl::Span<const int64> input_scatter_index,
update_scatter_index_to_input_index(update_scatter_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
updates_shape, window_indices_iteration_space,
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index) {
+ [&](absl::Span<const int64> update_window_index) {
return scatter_inner_loop_body(
update_window_index, input_scatter_index, update_scatter_index);
}));
@@ -2279,7 +2328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 rank = ShapeUtil::Rank(operand->shape());
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ auto func = [&](absl::Span<const int64> out_index) {
DimensionVector operand_index(rank);
for (int64 i = 0; i < rank; ++i) {
operand_index[i] =
@@ -2550,7 +2599,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// bound, call `f` with the base index.
static void IterateThroughWindow(
const Shape& window_shape, const Window& window, const Shape& base_shape,
- const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const absl::Span<const int64>& window_count_index,
const std::function<void(const std::vector<int64>&)>& f) {
const int64 rank = ShapeUtil::Rank(base_shape);
DimensionVector window_index(rank);
@@ -2569,7 +2618,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!out_of_bound) {
f(base_index);
}
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ } while (
+ IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index)));
}
template <typename IndexT>
@@ -2589,8 +2639,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> operand_indices(start.size());
auto result = absl::make_unique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2621,7 +2671,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> result_index(rank, 0);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
+ auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
result->Set<ReturnT>(result_index,
@@ -2675,8 +2725,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2712,8 +2762,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = absl::make_unique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ed4e159910..6d13f85cbb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -113,7 +113,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
std::vector<int64> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
- tensorflow::gtl::ArraySlice<int64>(fft_length));
+ absl::Span<const int64>(fft_length));
break;
}
case HloOpcode::kSend:
@@ -158,16 +158,26 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
break;
case HloOpcode::kReduce:
- TF_RET_CHECK(proto.operand_ids_size() == 2)
- << "Reduce instruction should have 2 operands but sees "
+ TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
+ << "Reduce instruction should have an even number of operands but "
+ "sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Reduce instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- instruction = CreateReduce(proto.shape(), operands(0), operands(1),
- std::vector<int64>(proto.dimensions().begin(),
- proto.dimensions().end()),
- computations(0));
+ {
+ const auto reduce_operands = all_operands();
+ auto inputs = absl::MakeSpan(reduce_operands)
+ .subspan(0, reduce_operands.size() / 2);
+ auto init_values =
+ absl::MakeSpan(reduce_operands)
+ .subspan(reduce_operands.size() / 2, reduce_operands.size());
+ instruction =
+ CreateReduce(proto.shape(), inputs, init_values,
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()),
+ computations(0));
+ }
break;
case HloOpcode::kSort: {
TF_RET_CHECK(proto.operand_ids_size() == 1 ||
@@ -375,6 +385,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());
}
+ static_cast<HloCustomCallInstruction*>(instruction.get())
+ ->set_feature_group_count(
+ std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
break;
case HloOpcode::kPad:
TF_RET_CHECK(proto.operand_ids_size() == 2)
@@ -509,13 +522,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
+ absl::Span<HloInstruction* const> parameters) {
return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
if (opcode == HloOpcode::kCopy) {
// It is impossible to copy an opaque shape, we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
@@ -617,13 +630,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK_EQ(HloOpcode::kTuple, opcode);
return CreateNary(shape, opcode, operands);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation) {
return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
@@ -638,7 +651,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length) {
+ absl::Span<const int64> fft_length) {
return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
fft_length);
}
@@ -682,7 +695,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id) {
@@ -692,7 +705,7 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups) {
return absl::make_unique<HloAllToAllInstruction>(shape, operands,
replica_groups);
@@ -754,12 +767,12 @@ HloInstruction::CreateCollectivePermute(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK(!operands.empty());
auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
@@ -805,16 +818,15 @@ HloInstruction::CreateCollectivePermute(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return absl::make_unique<HloDynamicSliceInstruction>(
shape, operand, start_indices, slice_sizes);
}
@@ -833,7 +845,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension) {
return absl::make_unique<HloConcatenateInstruction>(shape, operands,
dimension);
@@ -858,7 +870,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
auto instruction = absl::WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
@@ -866,9 +878,9 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
std::vector<HloInstruction*> all_args;
all_args.reserve(operands.size() * 2);
@@ -926,7 +938,7 @@ HloInstruction::CreateSelectAndScatter(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return absl::make_unique<HloBroadcastInstruction>(shape, operand,
broadcast_dimensions);
}
@@ -1004,7 +1016,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}
@@ -1022,7 +1034,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation) {
return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
fusion_computation);
@@ -1080,7 +1092,7 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
@@ -1092,14 +1104,14 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target) {
return absl::make_unique<HloCustomCallInstruction>(shape, operands,
custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
+ absl::Span<HloInstruction* const> elements) {
std::vector<Shape> element_shapes;
for (auto element : elements) {
element_shapes.push_back(element->shape());
@@ -1111,7 +1123,7 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return absl::make_unique<HloGatherInstruction>(
shape, operand, start_indices, gather_dim_numbers, slice_sizes);
}
@@ -1139,8 +1151,7 @@ bool HloInstruction::HasSideEffect() const {
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
@@ -1491,7 +1502,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
}
void HloInstruction::RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ absl::Span<const int> ascending_indices) {
if (ascending_indices.empty()) {
return;
}
@@ -1987,7 +1998,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
string operands;
- tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
+ absl::Span<HloInstruction* const> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
if (options.compact_operands() &&
slice.size() > kMaxOperandsToShowIfCompact) {
@@ -2749,10 +2760,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
case HloOpcode::kTranspose:
return UseKind::kUsePermutingElements;
case HloOpcode::kPad:
- case HloOpcode::kReduce:
// Pad reuses the padding value but not the padded array elements.
- // Reduce reuses the init value but not the operand array elements.
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
+ case HloOpcode::kReduce:
+ // Reduce reuses the init values but not the operand array elements.
+ return i >= Cast<HloReduceInstruction>(this)->input_count()
+ ? UseKind::kReuse
+ : UseKind::kUsePermutingElements;
case HloOpcode::kFusion:
// Uses the memoizing, recursive computation defined above.
return FusionReusesParamElements::Compute(i, *fused_expression_root());
@@ -3258,7 +3272,15 @@ void HloInstruction::set_convolution_dimension_numbers(
}
int64 HloInstruction::feature_group_count() const {
- return Cast<HloConvolutionInstruction>(this)->feature_group_count();
+ if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->feature_group_count();
+ }
+ return Cast<HloCustomCallInstruction>(this)->feature_group_count();
+}
+
+void HloInstruction::set_feature_group_count(int64 feature_group_count) {
+ Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
+ feature_group_count);
}
HloComputation* HloInstruction::select() const {
@@ -3297,7 +3319,7 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_slice_sizes() const {
+absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4a424cebc0..cca134e8b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -36,6 +36,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -49,7 +50,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
@@ -365,7 +365,7 @@ class HloInstruction {
// random numbers from a given distribution.
static std::unique_ptr<HloInstruction> CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ absl::Span<HloInstruction* const> parameters);
// Creates a unary instruction (one operand).
// Precondition: opcode must be a legitimate unary operation.
@@ -392,13 +392,13 @@ class HloInstruction {
// Precondition: opcode must be a legitimate variadic operation.
static std::unique_ptr<HloInstruction> CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
// at a given index)
static std::unique_ptr<HloInstruction> CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
// Creates a convolution op, where rhs is the convolutional filter
@@ -412,7 +412,7 @@ class HloInstruction {
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
@@ -449,7 +449,7 @@ class HloInstruction {
//
// TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
@@ -468,7 +468,7 @@ class HloInstruction {
// be concatenated in the order of 1, 2, 3; another Alltoall will be applied
// within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
static std::unique_ptr<HloInstruction> CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
// Creates a communitation instructions that permutes data cross replicas.
@@ -536,17 +536,15 @@ class HloInstruction {
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
// 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ HloInstruction* start_indices, absl::Span<const int64> slice_sizes);
// Creates a dynamic update slice instruction, which updates a slice
// of 'operand' with 'update' and 'start_indices'.
@@ -557,7 +555,7 @@ class HloInstruction {
// Creates a concatenate instruction, where the operands are concatenated on
// the provided dimension.
static std::unique_ptr<HloInstruction> CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
@@ -569,7 +567,7 @@ class HloInstruction {
// f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// A more general, multiple-argument version of the above.
@@ -584,9 +582,9 @@ class HloInstruction {
// ...
// TODO(b/112040122): Add support to this in HLO passes and in backends.
static std::unique_ptr<HloInstruction> CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Creates a reduce-window instruction, where the computation (given
@@ -623,7 +621,7 @@ class HloInstruction {
// Creates a broadcast instruction.
static std::unique_ptr<HloInstruction> CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
@@ -653,7 +651,7 @@ class HloInstruction {
// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a sort op, with a keys operand, and an optional values operand.
static std::unique_ptr<HloInstruction> CreateSort(
@@ -679,7 +677,7 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -703,37 +701,37 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation);
// Creates a call instruction that applies the given computation on the given
// operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements);
+ absl::Span<HloInstruction* const> elements);
// Creates a reverse instruction, which reverses the order of the elements
// in the specified dimensions.
static std::unique_ptr<HloInstruction> CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a Afterall instruction used for joining or creating new values of
// token type which thread through side-effecting operations. Operands must
// all be tokens, and there must be at least one operand.
static std::unique_ptr<HloInstruction> CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates an AfterAll instruction which creates a token type out of thin air
// (no operands). This is a separate method from CreateAfterAll to facility
@@ -867,8 +865,8 @@ class HloInstruction {
return false;
}
- if (!ContainersEqual(precision_config_.operand_precision(),
- other.precision_config_.operand_precision())) {
+ if (!absl::c_equal(precision_config_.operand_precision(),
+ other.precision_config_.operand_precision())) {
return false;
}
@@ -1124,8 +1122,7 @@ class HloInstruction {
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
@@ -1478,6 +1475,8 @@ class HloInstruction {
// dimension and output feature dimension.
int64 feature_group_count() const;
+ void set_feature_group_count(int64 feature_group_count);
+
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* select() const;
@@ -1505,7 +1504,7 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
// Delegates to HloGatherInstruction::gather_slice_sizes.
- tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const;
+ absl::Span<const int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
@@ -1531,7 +1530,7 @@ class HloInstruction {
// Removes a list of operands with the given indices in ascending order.
void RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices);
+ absl::Span<const int> ascending_indices);
void AppendComputation(HloComputation* computation) {
called_computations_.push_back(computation);
@@ -1561,8 +1560,7 @@ class HloInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
// TODO(b/80131774): This should be pure virtual.
LOG(FATAL) << "Unimplemented method.";
@@ -1608,7 +1606,7 @@ class HloInstruction {
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Adds a user for this instruction.
void AddUser(HloInstruction* user);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8b0b90dfb3..76b0e940a6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -39,10 +39,8 @@ namespace {
using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;
-class HloInstructionTest : public HloTestBase {
+class HloInstructionTest : public HloVerifiedTestBase {
protected:
- HloInstructionTest() {}
-
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
};
@@ -1086,16 +1084,14 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
// Fused expression:
- //
- // x y
- // \ / \
- // min broadcast
+ // y
+ // /
+ // x broadcast
+ // \ / |
+ // min |
// \ /
// sub
//
- // The fusion instruction is elementwise on `x` because the only path from x
- // to sub contains only elementwise operations. It is not elementwise on `y`
- // because the path y->broadcast->sub is not all elementwise.
const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
@@ -1104,10 +1100,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
- HloInstruction* min = builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y));
HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0}));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {}));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
@@ -1118,10 +1114,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
EXPECT_FALSE(fusion->IsElementwise());
for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
++operand_idx) {
- if (fusion->operand(operand_idx) == x) {
- EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
- } else {
+ if (fusion->operand(operand_idx) == y) {
EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
+ } else {
+ EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
}
}
}
@@ -1248,7 +1244,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape, one, {1}));
+ HloInstruction::CreateBroadcast(data_shape, one, {}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, dot, add_operand));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1743,5 +1739,23 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
<< clone->convolution_dimension_numbers().DebugString();
}
+TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
+ constexpr char kHloString[] = R"(
+ HloModule test_module
+ ENTRY test {
+ arg0 = f32[1,2,1] parameter(0)
+ arg1 = f32[1,1,1] parameter(1)
+ ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
+ dim_labels=b0f_0io->b0f, operand_precision={high,default}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+
+ auto clone = conv->Clone();
+ EXPECT_THAT(clone->precision_config().operand_precision(),
+ ::testing::ElementsAre(PrecisionConfigProto::HIGH,
+ PrecisionConfigProto::DEFAULT));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ffc74cfedd..e46afa764f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -91,8 +91,7 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloBatchNormTrainingInstruction>(
@@ -113,8 +112,7 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
return absl::make_unique<HloBatchNormInferenceInstruction>(
@@ -135,8 +133,7 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
return absl::make_unique<HloBatchNormGradInstruction>(
@@ -144,9 +141,9 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
new_operands[4], epsilon(), feature_index());
}
-HloFftInstruction::HloFftInstruction(
- const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length)
+HloFftInstruction::HloFftInstruction(const Shape& shape,
+ HloInstruction* operand, FftType fft_type,
+ absl::Span<const int64> fft_length)
: HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
fft_length_.assign(fft_length.begin(), fft_length.end());
AppendOperand(operand);
@@ -177,8 +174,7 @@ bool HloFftInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
@@ -232,8 +228,7 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand,
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloSendInstruction>(
@@ -250,8 +245,7 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
std::unique_ptr<HloInstruction>
HloSendDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloSendDoneInstruction>(
@@ -271,8 +265,7 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape,
}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloRecvInstruction>(
@@ -293,8 +286,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
std::unique_ptr<HloInstruction>
HloRecvDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloRecvDoneInstruction>(
@@ -303,7 +295,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCollectiveInstruction::HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloInstruction(opcode, shape), replica_groups_(replica_groups) {
for (auto operand : operands) {
@@ -337,15 +329,14 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
/*eq_computations*/) const {
const auto& casted_other =
static_cast<const HloCollectiveInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- });
+ return absl::c_equal(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return absl::c_equal(a.replica_ids(), b.replica_ids());
+ });
}
HloAllReduceInstruction::HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
const absl::optional<int64>& all_reduce_id)
@@ -393,8 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloAllReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllReduceInstruction>(
shape, new_operands, to_apply(), replica_groups(),
@@ -402,15 +392,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
}
HloAllToAllInstruction::HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
replica_groups) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
replica_groups());
@@ -452,25 +441,23 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath(
/*eq_computations*/) const {
const auto& casted_other =
static_cast<const HloCollectivePermuteInstruction&>(other);
- return ContainersEqual(
- source_target_pairs(), casted_other.source_target_pairs(),
- [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) {
- return a == b;
- });
+ return absl::c_equal(source_target_pairs(),
+ casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a,
+ const std::pair<int64, int64>& b) { return a == b; });
}
std::unique_ptr<HloInstruction>
HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloCollectivePermuteInstruction>(
shape, new_operands[0], source_target_pairs());
}
-HloReverseInstruction::HloReverseInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+HloReverseInstruction::HloReverseInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kReverse, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
AppendOperand(operand);
@@ -498,8 +485,7 @@ bool HloReverseInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
@@ -507,7 +493,7 @@ std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
}
HloConcatenateInstruction::HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension)
: HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
for (auto operand : operands) {
@@ -539,16 +525,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConcatenateInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
dimensions(0));
}
HloReduceInstruction::HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
@@ -583,10 +568,9 @@ bool HloReduceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 2);
+ CHECK_EQ(new_operands.size() % 2, 0);
return absl::make_unique<HloReduceInstruction>(shape, new_operands,
dimensions(), to_apply());
}
@@ -623,8 +607,7 @@ bool HloSortInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
@@ -634,7 +617,7 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
CHECK_EQ(shape.dimensions().size(), dimensions.size());
@@ -678,8 +661,7 @@ bool HloTransposeInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloTransposeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
@@ -688,7 +670,7 @@ HloTransposeInstruction::CloneWithNewOperandsImpl(
HloBroadcastInstruction::HloBroadcastInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension)
+ absl::Span<const int64> broadcast_dimension)
: HloInstruction(HloOpcode::kBroadcast, shape),
dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
AppendOperand(operand);
@@ -717,17 +699,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloBroadcastInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
dimensions());
}
-HloMapInstruction::HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation)
+HloMapInstruction::HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation)
: HloInstruction(HloOpcode::kMap, shape) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -776,17 +757,16 @@ bool HloMapInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
}
-HloSliceInstruction::HloSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides)
+HloSliceInstruction::HloSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides)
: HloInstruction(HloOpcode::kSlice, shape),
slice_starts_(start_indices.begin(), start_indices.end()),
slice_limits_(limit_indices.begin(), limit_indices.end()),
@@ -837,8 +817,7 @@ bool HloSliceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloSliceInstruction>(
@@ -891,8 +870,7 @@ bool HloConstantInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
}
@@ -949,8 +927,7 @@ bool HloTraceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
}
@@ -968,7 +945,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape,
HloFusionInstruction::HloFusionInstruction(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation)
: HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
for (auto operand : operands) {
@@ -1375,8 +1352,7 @@ bool HloFusionInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloModule* module = context != nullptr ? context->module() : GetModule();
HloComputation* new_fused_computation = nullptr;
@@ -1414,7 +1390,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() {
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
+ absl::Span<HloInstruction* const> parameters)
: HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
for (HloInstruction* param : parameters) {
AppendOperand(param);
@@ -1445,8 +1421,7 @@ bool HloRngInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloRngInstruction>(shape, distribution_,
new_operands);
@@ -1482,8 +1457,7 @@ bool HloParameterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloParameterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
name());
@@ -1518,8 +1492,7 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloGetTupleElementInstruction>(
@@ -1561,8 +1534,7 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloReducePrecisionInstruction>(
@@ -1602,8 +1574,7 @@ bool HloInfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloInfeedInstruction>(
@@ -1648,8 +1619,7 @@ bool HloOutfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloOutfeedInstruction>(
@@ -1690,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_window() = window_;
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1711,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other =
static_cast<const HloConvolutionInstruction&>(other);
+ if (feature_group_count_ != other.feature_group_count()) {
+ return false;
+ }
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
@@ -1719,8 +1693,7 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConvolutionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
@@ -1764,8 +1737,7 @@ bool HloReduceWindowInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReduceWindowInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloReduceWindowInstruction>(
@@ -1813,8 +1785,7 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloSelectAndScatterInstruction>(
@@ -1823,11 +1794,11 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
}
HloCustomCallInstruction::HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
- custom_call_target_(custom_call_target.begin(),
- custom_call_target.end()) {
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1843,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1857,6 +1829,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
"dim_labels=",
ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -1884,13 +1859,15 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.convolution_dimension_numbers()))) {
return false;
}
+ if (feature_group_count_ != casted_other.feature_group_count_) {
+ return false;
+ }
return custom_call_target_ == casted_other.custom_call_target_;
}
std::unique_ptr<HloInstruction>
HloCustomCallInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
shape, new_operands, custom_call_target());
@@ -1900,6 +1877,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
+ cloned->set_feature_group_count(feature_group_count_);
return std::move(cloned);
}
@@ -1933,8 +1911,7 @@ bool HloPadInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
@@ -1943,7 +1920,7 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kDynamicSlice, shape),
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
AppendOperand(operand);
@@ -1973,8 +1950,7 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloDynamicSliceInstruction>(
@@ -1984,7 +1960,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
HloGatherInstruction::HloGatherInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
AppendOperand(start_indices);
@@ -2013,10 +1989,9 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> offset_dims,
- tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
- tensorflow::gtl::ArraySlice<int64> start_index_map,
- int64 index_vector_dim) {
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
for (int64 output_window_dim : offset_dims) {
gather_dim_numbers.add_offset_dims(output_window_dim);
@@ -2059,8 +2034,7 @@ bool HloGatherInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloGatherInstruction>(
@@ -2104,9 +2078,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const {
/* static */ ScatterDimensionNumbers
HloScatterInstruction::MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim) {
ScatterDimensionNumbers scatter_dim_numbers;
for (int64 update_window_dim : update_window_dims) {
@@ -2146,8 +2120,7 @@ bool HloScatterInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloScatterInstruction>(
@@ -2179,8 +2152,7 @@ bool HloIotaInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ee6e337b6a..3230383579 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -67,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -82,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -97,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -106,7 +103,7 @@ class HloFftInstruction : public HloInstruction {
public:
explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
FftType fft_type() const { return fft_type_; }
const std::vector<int64>& fft_length() const { return fft_length_; }
@@ -124,8 +121,7 @@ class HloFftInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes FFT type for an FFT instruction.
@@ -174,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -187,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -200,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -213,8 +206,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -227,7 +219,7 @@ class HloCollectiveInstruction : public HloInstruction {
protected:
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
HloInstructionProto ToProto() const override;
@@ -245,7 +237,7 @@ class HloCollectiveInstruction : public HloInstruction {
class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
@@ -274,8 +266,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The string representation of the barrier config used for CrossReplicaSum.
@@ -290,14 +281,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -324,8 +314,7 @@ class HloCollectivePermuteInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
const std::vector<std::pair<int64, int64>> source_target_pairs_;
@@ -334,7 +323,7 @@ class HloCollectivePermuteInstruction : public HloInstruction {
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -350,8 +339,7 @@ class HloReverseInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -359,9 +347,9 @@ class HloReverseInstruction : public HloInstruction {
class HloConcatenateInstruction : public HloInstruction {
public:
- explicit HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- int64 dimension);
+ explicit HloConcatenateInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ int64 dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -379,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -388,26 +375,28 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
- explicit HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reduce_computation);
+ explicit HloReduceInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns the number of input arrays (and, consequentially, the number of
+ // init values) this reduce has.
+ int64 input_count() const { return operand_count() / 2; }
+
// Returns the input tensors to be reduced.
- tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0,
- operand_count() / 2);
+ absl::Span<HloInstruction* const> inputs() const {
+ return absl::MakeSpan(operands()).subspan(0, input_count());
}
// Returns the init values of the reduction.
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands(), operand_count() / 2, operand_count());
+ absl::Span<HloInstruction* const> init_values() const {
+ return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
}
private:
@@ -419,8 +408,7 @@ class HloReduceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -448,8 +436,7 @@ class HloSortInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -457,9 +444,8 @@ class HloSortInstruction : public HloInstruction {
class HloTransposeInstruction : public HloInstruction {
public:
- explicit HloTransposeInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -477,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -486,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction {
class HloBroadcastInstruction : public HloInstruction {
public:
- explicit HloBroadcastInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension);
+ explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> broadcast_dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -504,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -513,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction {
class HloMapInstruction : public HloInstruction {
public:
- explicit HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+ explicit HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -533,8 +516,7 @@ class HloMapInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -543,9 +525,9 @@ class HloMapInstruction : public HloInstruction {
class HloSliceInstruction : public HloInstruction {
public:
explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
HloInstructionProto ToProto() const override;
@@ -584,8 +566,7 @@ class HloSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [begin, end) index range for a slice.
@@ -627,8 +608,7 @@ class HloConstantInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
@@ -649,8 +629,7 @@ class HloTraceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// TODO(b/36360764): Remove unique_ptr wrapping.
std::unique_ptr<Literal> literal_;
@@ -661,10 +640,9 @@ class HloFusionInstruction : public HloInstruction {
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
HloInstruction* fused_root);
- explicit HloFusionInstruction(
- const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* fusion_computation);
+ explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* fusion_computation);
string ToCategory() const override;
// Returns a serialized representation of this instruction.
@@ -777,8 +755,7 @@ class HloFusionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The type of the fusion. Used by kFusion only.
@@ -787,9 +764,9 @@ class HloFusionInstruction : public HloInstruction {
class HloRngInstruction : public HloInstruction {
public:
- explicit HloRngInstruction(
- const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ explicit HloRngInstruction(const Shape& shape,
+ RandomDistribution distribution,
+ absl::Span<HloInstruction* const> parameters);
// Returns the random distribution for this rng node.
RandomDistribution random_distribution() const { return distribution_; }
// Returns a serialized representation of this instruction.
@@ -806,8 +783,7 @@ class HloRngInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The distribution requested for random number generation.
@@ -832,8 +808,7 @@ class HloParameterInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 parameter_number_ = 0;
@@ -857,8 +832,7 @@ class HloGetTupleElementInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 tuple_index_ = -1;
@@ -886,8 +860,7 @@ class HloReducePrecisionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The bit sizes for a reduce-precision operation.
@@ -924,8 +897,7 @@ class HloInfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The string representation of the infeed configuration.
@@ -957,8 +929,7 @@ class HloOutfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Shape of outfeed request.
@@ -999,8 +970,7 @@ class HloConvolutionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
// Describes the dimension numbers used for a convolution.
@@ -1031,8 +1001,7 @@ class HloReduceWindowInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
@@ -1080,17 +1049,16 @@ class HloSelectAndScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- absl::string_view custom_call_target);
+ explicit HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1111,6 +1079,10 @@ class HloCustomCallInstruction : public HloInstruction {
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
+ void set_feature_group_count(int64 feature_group_count) {
+ feature_group_count_ = feature_group_count;
+ }
+ int64 feature_group_count() const { return feature_group_count_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1123,8 +1095,7 @@ class HloCustomCallInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Name of a global symbol to call, only present for kCustomCall.
string custom_call_target_;
@@ -1132,6 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+ // The number of feature groups. This is used for grouped convolutions.
+ int64 feature_group_count_;
};
class HloPadInstruction : public HloInstruction {
@@ -1153,8 +1126,7 @@ class HloPadInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The padding configuration that describes the edge padding and interior
@@ -1164,10 +1136,10 @@ class HloPadInstruction : public HloInstruction {
class HloDynamicSliceInstruction : public HloInstruction {
public:
- explicit HloDynamicSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ explicit HloDynamicSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes);
// Old methods kept for smooth subclassing transition END.
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
@@ -1189,8 +1161,7 @@ class HloDynamicSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [start, start + size) range size for a dynamic slice
@@ -1204,12 +1175,12 @@ class HloGatherInstruction : public HloInstruction {
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_slice_sizes() const {
+ absl::Span<const int64> gather_slice_sizes() const {
return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
@@ -1219,10 +1190,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> offset_dims,
- tensorflow::gtl::ArraySlice<int64> collapsed_slice_dims,
- tensorflow::gtl::ArraySlice<int64> start_index_map,
- int64 index_vector_dim);
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim);
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -1232,8 +1202,7 @@ class HloGatherInstruction : public HloInstruction {
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
@@ -1258,9 +1227,9 @@ class HloScatterInstruction : public HloInstruction {
// Creates an instance of ScatterDimensionNumbers.
static ScatterDimensionNumbers MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim);
private:
@@ -1272,8 +1241,7 @@ class HloScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
@@ -1296,8 +1264,7 @@ class HloIotaInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
const int64 iota_dimension_;
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 8350285e67..d9be841dd7 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -406,11 +406,7 @@ TokKind HloLexer::LexString() {
absl::string_view raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
- // TODO(b/113077997): Change to absl::CUnescape once it works properly with
- // copy-on-write std::string implementations.
- if (!tensorflow::str_util::CUnescape( // non-absl ok
- tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok
- &str_val_, &error)) {
+ if (!absl::CUnescape(raw, &str_val_, &error)) {
LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
return TokKind::kError;
}
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 9ace0d76e0..5502e565b6 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -188,6 +188,7 @@ HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
+HLO_MATCHER(Iota);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
HLO_MATCHER(Le);
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 78167335c8..3a1bc4e328 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -353,7 +353,7 @@ bool IsUsedOutsideSubcomputation(
} // anonymous namespace
HloInstruction* HloModule::OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation) {
auto builder = HloComputation::Builder(outlined_computation_name);
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index cf129b835d..3c3371426b 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -192,7 +192,7 @@ class HloModule {
// order (root of outlined instructions last). TODO(jingyue): takes a set of
// instructions and topologically sorts them.
HloInstruction* OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation);
// Returns a randomly generated uint64.
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d70328c8a3..d83ee71490 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -193,7 +193,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
}
std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> roots;
for (HloComputation* computation : computations) {
for (HloInstruction* instruction : computation->instructions()) {
@@ -293,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
}
Status HloModuleGroupUtil::VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
@@ -324,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index c25ca1aff5..309c23045d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -56,7 +56,7 @@ class HloModuleGroupUtil {
// Returns the root instructions of the computations.
std::vector<HloInstruction*> RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Visit state of each instruction during DFS traversal.
enum VisitState {
@@ -93,15 +93,14 @@ class HloModuleGroupUtil {
HloInstruction* root);
// Verifies that the computations are well-formed (e.g., no cycles).
- Status VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ Status VerifyComputations(absl::Span<HloComputation* const> computations);
// Below Reachability utils resemble those in HloComputation, except that
// they can handle instructions across multiple computations.
//
// Creates the reachability map for the instructions in the computations.
StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Updates the reachability of the given instruction, taking the global
// predeccessorss and successors into account.
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 209ad5e58c..4bc1bacd7d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -44,7 +44,7 @@ class HloModuleTest : public HloTestBase {
// Creates a computation which calls the given zero-parameter computations.
std::unique_ptr<HloComputation> CreateCallComputation(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto builder = HloComputation::Builder("Call");
for (auto computation : computations) {
builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index eae4508b24..ea8e6a239a 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -65,6 +65,7 @@ class HloParser {
StatusOr<HloSharding> ParseShardingOnly();
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
+ StatusOr<PaddingConfig> ParsePaddingConfigOnly();
// Stand-alone parsing utility for a single instruction worth of text.
Status ParseSingleInstruction(HloComputation::Builder* builder,
@@ -306,7 +307,7 @@ bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
- tensorflow::gtl::ArraySlice<std::vector<int64>> groups) {
+ absl::Span<const std::vector<int64>> groups) {
std::vector<ReplicaGroup> replica_groups;
absl::c_transform(groups, std::back_inserter(replica_groups),
[](const std::vector<int64>& ids) {
@@ -997,11 +998,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
- operands.size() / 2),
+ absl::Span<HloInstruction* const>(operands).subspan(
+ 0, operands.size() / 2),
/*init_values=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands, operands.size() / 2, operands.size()),
+ absl::Span<HloInstruction* const>(operands).subspan(
+ operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
@@ -3156,6 +3157,18 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
return dnums;
}
+StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
+ lexer_.Lex();
+ PaddingConfig padding_config;
+ if (!ParsePaddingConfig(&padding_config)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
+ }
+ return padding_config;
+}
+
Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
string* root_name) {
TF_RET_CHECK(missing_instruction_hook_ == nullptr);
@@ -3238,4 +3251,10 @@ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
return parser.ParseConvolutionDimensionNumbersOnly();
}
+StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParsePaddingConfigOnly();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 0c64b50481..1882a184da 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -59,6 +59,9 @@ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
StatusOr<HloSharding> ParseSharding(absl::string_view str);
+// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1".
+StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index ba07ec432e..759789437c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -382,7 +382,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default}
}
)"
@@ -1725,6 +1725,25 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
+ const string original = "0_1x2_3";
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
+ EXPECT_EQ(original, PaddingConfigToString(dnums));
+}
+
+TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
+ const string original = "0_1_0x2_3_4";
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
+ EXPECT_EQ(original, PaddingConfigToString(dnums));
+}
+
+TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
+ // The extra "_0" gets added to the canonical string because the other dim has
+ // interior padding.
+ EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
+}
+
TEST_F(HloParserTest, NontupleInfeed) {
const string original = R"(HloModule nontuple_infeed:
ENTRY nontuple_infeed {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 01b088a957..961930f0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
+ absl::Span<const HloInstruction* const> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
@@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap(
}
bool HloReachabilityMap::SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
BitVector& bit_vector = GetBitVector(instruction);
tmp_bit_vector_ = bit_vector;
@@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion(
}
void HloReachabilityMap::FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
}
void HloReachabilityMap::SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 48215d32a8..b66a2aa4bd 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,7 @@ class HloReachabilityMap {
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
explicit HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
+ absl::Span<const HloInstruction* const> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
@@ -54,13 +54,12 @@ class HloReachabilityMap {
// vector in the internal graph of this HloReachabilityMap for the given
// instruction and does not transitively update any other part of the
// adjacency matrix.
- bool SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
- const HloInstruction* instruction);
+ bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
+ const HloInstruction* instruction);
// As above, but faster because it does not check if the reachability changed.
void FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction);
// Sets entry so that IsReachable(a, b) will return true
@@ -141,7 +140,7 @@ class HloReachabilityMap {
// Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
void SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector);
// Return the index of the given instruction. The value is used to index into
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 569d2e5d2d..c9629926ea 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -202,8 +202,8 @@ class InstructionList {
// On object construction this ordinal is precisely the instruction's index
// in the list. Later, instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
- void InsertBeforeInstructions(
- Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
+ void InsertBeforeInstructions(Item* to_insert,
+ absl::Span<Item* const> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
<< absl::StrJoin(before_instructions, ", ",
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 7bd8a4a544..66ac1f66fd 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -106,7 +106,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals) {
+ const absl::Span<const Literal* const> literals) {
std::vector<ScopedShapedBuffer> buffers;
for (const Literal* literal : literals) {
CHECK(literal != nullptr);
@@ -118,7 +118,7 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals) {
+ const absl::Span<const std::unique_ptr<Literal>> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
@@ -137,8 +137,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
@@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes, ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
@@ -169,8 +169,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Get service run options.
se::Stream stream(backend().default_stream_executor());
stream.Init();
@@ -190,8 +190,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
@@ -226,8 +226,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
// no arguments.
std::vector<const ShapedBuffer*> argument_buffer_ptrs(
options.num_replicas * options.arguments.size() + 1);
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- argument_buffer_slices;
+ std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
int64 index = 0;
for (int64 i = 0; i < options.num_replicas; ++i) {
int64 device = device_assignment(i, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index cfc519063e..76d8b92bed 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -104,9 +104,9 @@ class HloRunner {
// Transfers data between the host and device.
StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals);
+ const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals);
+ const absl::Span<const std::unique_ptr<Literal>> literals);
StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
const ShapedBuffer& buffer);
@@ -117,24 +117,24 @@ class HloRunner {
// optimization.
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
+ const absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ const absl::Span<const std::unique_ptr<Literal>> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
+ const absl::Span<const ScopedShapedBuffer> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// Executes a given HLO module into a set of replicas, and returns a map
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 930801288a..d49d09d459 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -269,7 +269,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto abs_abs1 = builder.AddInstruction(
HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*>({abs_abs1})));
+ absl::Span<HloInstruction* const>({abs_abs1})));
auto tuple_elm = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 980dae07ce..de7e6b53d4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -54,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
return HloSharding(flattened_list);
}
-HloSharding HloSharding::Tuple(
- const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings) {
+HloSharding HloSharding::Tuple(const Shape& tuple_shape,
+ absl::Span<const HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
for (auto& sharding : shardings) {
CHECK(!sharding.IsTuple()) << sharding.ToString();
@@ -142,7 +141,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
- tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
+ tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
if (d == device) {
ret_index = {index.begin(), index.end()};
}
@@ -151,8 +150,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
return ret_index;
}
-int64 HloSharding::DeviceForTileIndex(
- tensorflow::gtl::ArraySlice<int64> index) const {
+int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
CHECK(!replicated_);
CHECK(!IsTuple());
if (maximal_) {
@@ -319,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
Status status = Status::OK();
std::set<int64> seen_cores;
tile_assignment_.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, int32 core) {
+ [&](absl::Span<const int64> indices, int32 core) {
// Don't overwrite a bad status, so we report the first error.
if (status.ok()) {
if (core >= num_devices) {
@@ -429,12 +427,23 @@ Shape HloSharding::TileShape(const Shape& shape) const {
HloSharding HloSharding::GetSubSharding(const Shape& shape,
const ShapeIndex& index) const {
CHECK(IsTuple());
-
- Shape sub_shape = ShapeUtil::GetSubshape(shape, index);
- ShapeTree<HloSharding> sub_shape_tree(sub_shape, Replicate());
- sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {});
- return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree)
- : sub_shape_tree.element(ShapeIndex({}));
+ int64 sharding_index = 0;
+ const Shape* sub_shape = &shape;
+ for (int64 idx : index) {
+ for (int64 i = 0; i < idx; ++i) {
+ sharding_index +=
+ ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
+ }
+ sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
+ }
+ if (ShapeUtil::IsTuple(*sub_shape)) {
+ auto begin_it = tuple_elements_.begin() + sharding_index;
+ std::vector<HloSharding> sub_shardings(
+ begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
+ return HloSharding::Tuple(*sub_shape, sub_shardings);
+ } else {
+ return tuple_elements_[sharding_index];
+ }
}
absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index be51c3f55b..9775505f86 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -23,12 +23,12 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -66,7 +66,7 @@ class HloSharding {
// shardings must match the number of leaf nodes in tuple_shape. For
// empty tuples, the shardings array must have one element.
static HloSharding Tuple(const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings);
+ absl::Span<const HloSharding> shardings);
// Creates a new sharding for a tuple type, with a single input sharding
// repeated on each leaf.
@@ -132,7 +132,7 @@ class HloSharding {
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
// REQUIRES: !IsTuple()
- int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
+ int64 DeviceForTileIndex(absl::Span<const int64> index) const;
// Given a device ID, returns the offset within the specified shape of the
// tile that should be executed on the given core. This returns the lower
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 6e9b96488c..34cba6136f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -372,7 +372,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
}
StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
// If we are here, all the instructions being passed had the same sharding
// (or no sharding), by the means of the ShardingMatches() API.
// As such, no kDomain was inserted, and here we are asked to extract the
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index 7a6b0d9abc..cba5db927a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 2341f8ada0..80634677e7 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -29,8 +29,8 @@ limitations under the License.
namespace xla {
namespace {
-Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> contents) {
+Array<int64> MakeArray(absl::Span<const int64> dimensions,
+ absl::Span<const int64> contents) {
Array<int64> a(dimensions);
std::copy(contents.begin(), contents.end(), a.begin());
return a;
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index e0c1326177..773fc7d225 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -149,7 +149,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace
void HloValue::SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions) {
+ absl::Span<const HloPosition> positions) {
CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
// The positions must be unique and should not contain the defining position
@@ -222,8 +222,7 @@ string HloValueSet::ToString() const {
}));
}
-bool HloValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
+bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
HloValueSet union_set;
for (const HloValueSet* input : inputs) {
for (const HloValue* value : input->values()) {
@@ -254,7 +253,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
}
bool InstructionValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK_GT(inputs.size(), 0);
for (int i = 1; i < inputs.size(); ++i) {
DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h
index a1151f65e0..b6670d409b 100644
--- a/tensorflow/compiler/xla/service/hlo_value.h
+++ b/tensorflow/compiler/xla/service/hlo_value.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -108,8 +108,7 @@ class HloValue : public BufferValue {
// Sets the positions in the module at which the HloValue appears. Updates
// uses. Should be called once and only once. The defining position should not
// be included in 'positions' as this is set at construction time.
- void SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions);
+ void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
// Returns whether this value is a phi value.
bool is_phi() const { return is_phi_; }
@@ -186,14 +185,14 @@ class HloValueSet {
public:
HloValueSet() = default;
- explicit HloValueSet(tensorflow::gtl::ArraySlice<const HloValue*> values)
+ explicit HloValueSet(absl::Span<const HloValue* const> values)
: values_(values.begin(), values.end()) {
SortAndUniquifyValues();
}
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
// Return the vector of HloValues in the set. Values in the vector are unique
// and stably sorted by value id.
@@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree<HloValueSet> {
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
string ToString() const;
};
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f1b29c2559..95516dec74 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -288,14 +288,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return InvalidArgument("Variadic reduce is not supported.");
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
}
- return CheckShape(
- reduce,
- ShapeInference::InferReduceShape(
- {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
- reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
+ return CheckShape(reduce, ShapeInference::InferReduceShape(
+ operand_shapes, reduce->dimensions(),
+ reduce->to_apply()->ComputeProgramShape()));
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
@@ -700,8 +699,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
instruction->opcode(), instruction->operands()));
}
-string ComputationsToString(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+string ComputationsToString(absl::Span<HloComputation* const> computations) {
return absl::StrJoin(computations, ",",
[](string* s, const HloComputation* computation) {
s->append(computation->name());
@@ -1069,9 +1067,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RET_CHECK(instruction->parent() == computation);
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(
- ContainersEqual(instruction->called_computations(),
- {instruction->fused_instructions_computation()}))
+ TF_RET_CHECK(instruction->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 70b741353d..0cac210c24 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -34,6 +34,8 @@ namespace {
using ::testing::HasSubstr;
+// This class cannot be converted to use HloVerifiedTestBase. It explicitly
+// uses HloTestBase to create and test malformed HLOs.
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
@@ -277,5 +279,84 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported"));
}
+TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(F32).CloneToUnique())),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(F32).CloneToUnique())),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+// Simple module containing a convolution as the root.
+static const char* const kConvHloString = R"(
+HloModule module
+ENTRY entry_computation {
+ param0 = f16[128,128,56,56] parameter(0)
+ param1 = f16[3,3,128,128] parameter(1)
+ zero_f16 = f16[] constant(0)
+ ROOT conv = f16[128,128,28,28] convolution(param0, param1),
+ window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
+})";
+
+TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_window_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive window dilation factor"));
+}
+
+TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_base_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive base area dilation factor"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index df88587492..f85d31d522 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,11 +26,6 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
- public:
- ImplicitBroadcastRemoverTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 43ef30d1eb..a4de02a890 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -35,7 +35,6 @@ using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
-using tensorflow::gtl::ArraySlice;
} // namespace
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
@@ -186,7 +185,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
+ absl::Span<const int64> output_dims, Shape shape) {
// We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
// `source` is the inner Gather(A, X).
@@ -252,8 +251,7 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
- Array* indices) {
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
@@ -314,7 +312,7 @@ namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
-int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
+int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));
int64 current_product = 1;
@@ -343,7 +341,8 @@ struct ReshapePassthroughDimPair {
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
- ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape) {
+ absl::Span<const int64> operand_shape,
+ absl::Span<const int64> result_shape) {
// A reshape can be seen as an index mapping from output index to input index:
//
// (i_0, ..., i_n) = f(o_0, ..., o_m)
@@ -420,7 +419,7 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// Return true if `dim` is stated as an passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
return absl::c_any_of(passthrough_dims,
[&](ReshapePassthroughDimPair passthrough_dim_pair) {
return passthrough_dim_pair.operand_dim == dim;
@@ -430,7 +429,8 @@ bool IsReshapePassthroughOperandDim(
// Maps `operand_dim` which must be an passthrough operand dimension to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
+ int64 operand_dim) {
auto it = absl::c_find_if(
passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
return passthrough_dim_pair.operand_dim == operand_dim;
@@ -439,9 +439,9 @@ int64 MapPassthroughOperandDimToResultDim(
return it->result_dim;
}
-int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
- ArraySlice<int64> result_shape,
- int64 source_passthrough_dim) {
+int64 FindSourcePositionForPassthroughResultDim(
+ absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
+ int64 source_passthrough_dim) {
VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
<< StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
<< "], " << source_passthrough_dim << ")";
@@ -499,7 +499,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
if (shape.dimensions(i) == 1) {
degenerate_dims_seen++;
- } else if (ArrayContains(operand->output_dims(), i)) {
+ } else if (absl::c_linear_search(operand->output_dims(), i)) {
new_output_dims.push_back(i - degenerate_dims_seen);
}
}
@@ -519,8 +519,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
}
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims) {
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
if (degenerate_dims.empty()) {
return operand;
}
@@ -873,7 +872,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
return nullptr;
}
- ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
+ absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
auto is_broadcasted_dim = [&](int64 output_dim) {
return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
};
@@ -896,7 +895,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// The scalar-indexed node "removes" the source dim and "inserts" the output
// dims. We do the opposite here to undo the scalar-indexed operation.
- ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+ absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
for (int64 i = output_dims.size() - 1; i >= 0; --i) {
CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
EraseAt(&simulated_index, output_dims[i]);
@@ -973,12 +972,12 @@ namespace {
// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
absl::optional<int64> GetOnlyNonContractingNonBatchDim(
- int64 rank, ArraySlice<int64> contracting_dims,
- ArraySlice<int64> batch_dims) {
+ int64 rank, absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
absl::optional<int64> result;
for (int64 dim = 0; dim < rank; dim++) {
- if (!ArrayContains(contracting_dims, dim) &&
- !ArrayContains(batch_dims, dim)) {
+ if (!absl::c_linear_search(contracting_dims, dim) &&
+ !absl::c_linear_search(batch_dims, dim)) {
if (result.has_value()) {
return absl::nullopt;
}
@@ -998,7 +997,8 @@ absl::optional<int64> GetOnlyNonContractingNonBatchDim(
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
- ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
contracting_dims, batch_dims);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 3fa7d749e1..dcfb725535 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -188,9 +188,7 @@ class IndexedArrayAnalysis {
// `output_dims` are the dimensions in the output array that are being used
// to compute an index into the `indices` array. See the class
// documentation and the overview for more details.
- tensorflow::gtl::ArraySlice<int64> output_dims() const {
- return output_dims_;
- }
+ absl::Span<const int64> output_dims() const { return output_dims_; }
private:
explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
@@ -265,8 +263,7 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes, Array* source,
- Array* indices);
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
@@ -303,7 +300,7 @@ class IndexedArrayAnalysis {
// G1 = [Arr[i] for i in I2]
StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+ absl::Span<const int64> output_dims, Shape shape);
// Reshapes a scalar-indexed node to remove the degenerate dimensions in its
// output. The result is always a scalar-indexed node.
@@ -313,8 +310,7 @@ class IndexedArrayAnalysis {
// Reshapes a scalar-indexed node such that the result has the degenerate
// dimensions `degenerate_dims`. The result is always a scalar-indexed node.
StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims);
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);
StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
const Shape& shape, ScalarIndexedConstantArray* operand);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index c34c32f7d3..2d03aebc1a 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,11 +22,6 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
- public:
- IndexedArrayAnalysisTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
index 5c193fceb9..5fd779ebf9 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 83313c7ec1..8c907eae0c 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -172,7 +172,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
});
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kBroadcast) {
+ if (operand->opcode() == HloOpcode::kBroadcast ||
+ operand->opcode() == HloOpcode::kIota) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
@@ -218,7 +219,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
InstructionFusion::HloInstructionSet
InstructionFusion::ComputeGloballyUnfusible(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
+ absl::Span<HloInstruction* const> post_order) {
// Forbid fusion of producers that:
// a) Need to be duplicated, unless they can be fused into all consumers
// via all paths.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 9802d4cfc1..00b658959a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -123,7 +123,7 @@ class InstructionFusion : public HloPassInterface {
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
HloInstructionSet ComputeGloballyUnfusible(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
+ absl::Span<HloInstruction* const> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 581f8d2e92..146c9052f1 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -89,6 +89,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -114,5 +115,6 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_headers_lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 2259dc1083..5dea124768 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@@ -111,7 +111,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return tensorflow::errors::Unimplemented(
"ExecuteAsyncOnStream is not yet supported on Interpreter.");
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index 91d8148d26..3b1ebce0c7 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
static int64 ShapeSizeBytes(const Shape& shape);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index db6b910b32..fbb9945784 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -22,9 +22,9 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_description.h"
@@ -47,7 +47,7 @@ limitations under the License.
namespace stream_executor {
namespace interpreter {
-using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+using Args = absl::Span<const DeviceMemoryBase>;
class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
public:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 5e5c93e3a2..6e17711f57 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -51,7 +52,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7505d7a5b3..021fe630ff 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index be12d7c90c..540bbb7c7a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -71,6 +71,7 @@ cc_library(
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
@@ -92,6 +93,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -108,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -163,6 +166,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -200,6 +204,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm//:core",
+ "@llvm//:support",
],
)
@@ -214,6 +219,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index fe5ec1cc66..b6ae4932f5 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -61,7 +61,7 @@ ENTRY while3 {
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
-; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0
; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index ad350613dd..cc2e862f2e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -99,9 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) {
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b) {
VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
// No need to use operand_arrays[0], the input array of the
@@ -129,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace(
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
@@ -173,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
}
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
@@ -183,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace(
}
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index e1631a62ae..fb3e4eb97c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -63,25 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace(
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
// where the input and output buffers share the same slice, so we can simply
// modify the input/output buffer without touching any of the other elements.
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b);
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b);
// Given a loop-fusion node whose root is a dynamic-update-slice op whose
// array-to-be-updated and output share the same buffer slice, emits
// (sequential) code for a fusion node that does the dynamic-update-slice in
// place.
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b);
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
// the given launch dimensions.
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 6d637cad6d..b606c993a2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
}
Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
std::vector<llvm::Type*> operand_elemental_ir_types;
for (HloInstruction* operand : operands) {
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index 30471480c4..44d21fa750 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <unordered_map>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
public:
using Generator = llvm_ir::ElementGenerator;
- FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
+ FusedIrEmitter(absl::Span<const llvm_ir::IrArray> parameter_arrays,
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
tiled_parameter_info_(nullptr),
@@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
private:
// Arrays of parameters of fusion instruction
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+ absl::Span<const llvm_ir::IrArray> parameter_arrays_;
const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
ElementalIrEmitter* elemental_emitter_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 6971220022..67f7423121 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
Delinearize(&multidim_, linear, shape, b);
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
llvm::Value* linear, const Shape& shape)
: multidim_(multidim.begin(), multidim.end()),
linear_(linear),
@@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
<< " should have a layout.";
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
const Shape& shape, llvm::IRBuilder<>* b)
: multidim_(multidim.begin(), multidim.end()),
layout_(shape.layout()),
@@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
llvm::Value* logical_linear_index =
- Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
- multidim_, common_factors[k].second,
+ Index(absl::Span<llvm::Value* const>(multidim_).subspan(
+ common_factors[k].second,
common_factors[k + 1].second - common_factors[k].second),
index_type_)
- .Linearize(
- tensorflow::gtl::ArraySlice<int64>(
- AsInt64Slice(output_shape.dimensions()),
- common_factors[k].second,
- common_factors[k + 1].second - common_factors[k].second),
- builder);
+ .Linearize(AsInt64Slice(output_shape.dimensions())
+ .subspan(common_factors[k].second,
+ common_factors[k + 1].second -
+ common_factors[k].second),
+ builder);
// Delinearizes logical_linear_index for the source array in row-major
// collapsed order. The first rank-1 indices are the remainder of the
// linear index by each dimension size.
@@ -185,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
- llvm::IRBuilder<>* builder) const {
+ const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
Index source_index(index_type_, multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
@@ -208,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice(
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
std::vector<llvm::Value*> operand_multidim_index =
Permute(dimension_mapping, multidim());
@@ -257,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
int64 rank = ShapeUtil::Rank(operand_shape);
std::vector<llvm::Value*> source_index(rank);
@@ -322,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
return Index(source_index, linear, operand_shape);
}
-llvm::Value* IrArray::Index::Linearize(
- tensorflow::gtl::ArraySlice<int64> dimensions,
- llvm::IRBuilder<>* builder) const {
+llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
+ llvm::IRBuilder<>* builder) const {
// Each dimension is multiplied by the product of the sizes of all
// earlier dimensions and added to the accumulator logical_linear_index.
CHECK_EQ(size(), dimensions.size());
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index e913c109b3..f4b05f29c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -21,12 +21,12 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -70,7 +70,7 @@ class IrArray {
// Constructs an index from multi-dimensional index "multidim". The linear
// index is set to nullptr.
- explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ explicit Index(absl::Span<llvm::Value* const> multidim,
llvm::Type* index_ty = nullptr)
: multidim_(multidim.begin(), multidim.end()) {
if (size() == 0) {
@@ -99,14 +99,14 @@ class IrArray {
// that it indexes into.
//
// Precondition: "shape" has a layout.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- const Shape& shape, llvm::IRBuilder<>* b);
+ Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
+ llvm::IRBuilder<>* b);
// Constructs an index from both a multi-dimensional index and a linear
// index. "shape" has the same meaning as that in the constructor that takes
// only a linear index.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- llvm::Value* linear, const Shape& shape);
+ Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
+ const Shape& shape);
const std::vector<llvm::Value*>& multidim() const { return multidim_; }
llvm::Value* linear() const { return linear_; }
@@ -145,17 +145,15 @@ class IrArray {
// by starting indices `starts` and stride values `strides`.
//
// Precondition: "this" is an index into a slice whose shape is `shape`.
- Index SourceIndexOfSlice(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
+ Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides,
llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a transpose from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfTranspose(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a bitcast from `operand_shape`
// to `shape`, returns the source index.
@@ -164,14 +162,13 @@ class IrArray {
// Given that "this" is the target index of a broadcast from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfBroadcast(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
// returns the index into the sole dimension 0 of the new shape.
- llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
+ llvm::Value* Linearize(absl::Span<const int64> dimensions,
llvm::IRBuilder<>* builder) const;
llvm::Type* GetType() const { return index_type_; }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index b152cf9275..43fec311f1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -235,7 +235,7 @@ class KernelSupportLibrary {
}));
}
- using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
+ using ArgumentVector = absl::Span<llvm::Value* const>;
// Generates the following control flow structure:
//
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index cb4d1db997..e5fbdbd51b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -28,7 +28,7 @@ namespace {
// Returns the indices of the first elements of all consecutive subarrays of the
// given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
std::vector<size_t> is = {0};
for (size_t i = 1; i < xs.size(); ++i) {
if (1 != xs[i] - xs[i - 1]) {
@@ -40,8 +40,7 @@ std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
+Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
std::vector<int64> dimensions;
for (size_t i = 1; i <= segs.size(); ++i) {
dimensions.push_back(std::accumulate(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index 8bd06c42c3..5ea05b3188 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex(
// for 021 transpose.
class TiledParameterInfo {
public:
- TiledParameterInfo(tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers,
+ TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
llvm::Value* y, llvm::Value* x)
: param_buffers_(param_buffers), y_(y), x_(x) {}
@@ -67,7 +67,7 @@ class TiledParameterInfo {
private:
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
// if the parameter is not tiled.
- tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers_;
+ absl::Span<llvm::Value* const> param_buffers_;
// The y coordinate within a tile.
llvm::Value* y_;
// The x coordinate within a tile.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 9f3329e7f0..219a9f221f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -241,7 +241,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
}
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ const Shape& shape, absl::Span<const int64> dimensions,
absl::string_view suffix) {
llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index 0a406bd90b..ac3bba3c9f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -21,13 +21,13 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -242,7 +242,7 @@ class ForLoopNest {
// size equals the rank of shape and there is a null for each
// dimension that is not in "dimensions".
IrArray::Index AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ const Shape& shape, absl::Span<const int64> dimensions,
absl::string_view suffix);
// Emits a series of nested loops for iterating over an operand array. Loops
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index f0db2a3761..1a53c026be 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) {
return AsString(buffer_string);
}
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b) {
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index dde50e19d1..f59baff263 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace llvm {
@@ -59,7 +59,7 @@ llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
}
template <typename T>
-llvm::ArrayRef<T> AsArrayRef(const tensorflow::gtl::ArraySlice<T>& slice) {
+llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
return llvm::ArrayRef<T>(slice.data(), slice.size());
}
@@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name);
// intrinsics (for example, "minnum") must include a type in overloaded_types
// for each overloaded type. Typically, overloaded intrinsics have only a single
// overloaded type.
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b);
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b);
// Emit float max. Emit maxnum intrinsic is fast math is disabled, or
// fcmp+select otherwise
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 1553b4fc91..0dc120e0b0 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
}
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
+ absl::Span<const IrArray> target_arrays,
llvm::IRBuilder<>* b)
: body_emitter_(MakeBodyEmitterForMultiOutputFusion(
target_element_generator,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 57d9d8bbc6..a537c00066 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -53,8 +53,7 @@ class LoopEmitter {
// This is used for multi-output fusion. target_element_generator must
// produce an LLVM struct with N elements.
LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
- llvm::IRBuilder<>* b);
+ absl::Span<const IrArray> target_arrays, llvm::IRBuilder<>* b);
LoopEmitter(const LoopEmitter&) = delete;
LoopEmitter& operator=(const LoopEmitter&) = delete;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index 00dd3f1638..944c79580c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -18,6 +18,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -59,15 +60,39 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
SetToFirstInsertPoint(if_data.true_block, b);
auto key1 = keys_array.EmitReadArrayElement(keys_index, b);
auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b);
+ auto compare_key1 = key1;
+ auto compare_key2 = key2;
auto key_type = keys_array.GetShape().element_type();
+ bool is_signed_comparison = true;
+ if (primitive_util::IsFloatingPointType(key_type)) {
+ // We would like a total order of floating point numbers so that the sort
+ // has a predictable behavior in the presence of NaNs. Rather than using
+ // floating point comparison, we use the following trick:
+ // If f is a float, and
+ // x = bit_cast<int32>(f);
+ // y = x < 0 ? 0x7FFFFFFF - x : x;
+ // then y is ordered as an int32 such that finite values have the obvious
+ // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
+ // and end of the ordering.
+ auto k = b->getInt(llvm::APInt::getSignedMaxValue(
+ key1->getType()->getPrimitiveSizeInBits()));
+ auto comparison_type = k->getType();
+ auto zero = llvm::ConstantInt::get(comparison_type, 0);
+ auto maybe_flip = [&](llvm::Value* v) {
+ return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero),
+ b->CreateSub(k, v), v);
+ };
+ compare_key1 = b->CreateBitCast(key1, comparison_type);
+ compare_key2 = b->CreateBitCast(key2, comparison_type);
+ compare_key1 = maybe_flip(compare_key1);
+ compare_key2 = maybe_flip(compare_key2);
+ } else if (!primitive_util::IsSignedIntegralType(key_type)) {
+ is_signed_comparison = false;
+ }
auto comparison =
- primitive_util::IsFloatingPointType(key_type)
- // TODO(b/26783907): Figure out how to handle NaNs.
- ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1)
- : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type)
- ? llvm::ICmpInst::ICMP_SLT
- : llvm::ICmpInst::ICMP_ULT,
- key2, key1);
+ b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
+ : llvm::ICmpInst::ICMP_ULT,
+ compare_key2, compare_key1);
// If key2 < key1
auto if_smaller_data =
EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 11ed6ee59f..7d49b8d6c2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
}
}
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
auto* store = b->CreateStore(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index cf6bf5d0b1..887fb61371 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
// Utilities for emitting LLVM IR related to HLO tuples.
@@ -65,8 +65,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 768105d9e1..0d0fb7946a 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -141,7 +141,7 @@ ExecutionOptions CreateExecutionOptions(
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options) {
const HloModuleProto& proto = computation.proto();
TF_RET_CHECK(proto.has_program_shape());
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 8f707ea904..3b4f0b5083 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/backend.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -48,7 +48,7 @@ class LocalService : public Service {
// compiler is responsible for freeing any memory it allocates this way.
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options);
// Returns the device ordinal that corresponds to the given replica number.
diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h
index f9ba5a5547..ceacab4ed7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer.h
+++ b/tensorflow/compiler/xla/service/logical_buffer.h
@@ -18,13 +18,13 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
new file mode 100644
index 0000000000..8269842426
--- /dev/null
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
+#include "absl/types/variant.h"
+namespace xla {
+
+se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
+ if (HasOwnership()) {
+ return absl::get<OwningDeviceMemory>(mem_).AsDeviceMemoryBase();
+ } else {
+ return absl::get<se::DeviceMemoryBase>(mem_);
+ }
+}
+
+bool MaybeOwningDeviceMemory::HasOwnership() const {
+ return absl::holds_alternative<OwningDeviceMemory>(mem_);
+}
+
+absl::optional<OwningDeviceMemory> MaybeOwningDeviceMemory::Release() {
+ if (!HasOwnership()) {
+ return {};
+ }
+ OwningDeviceMemory result = std::move(absl::get<OwningDeviceMemory>(mem_));
+ mem_ = result.AsDeviceMemoryBase();
+ return absl::make_optional<OwningDeviceMemory>(std::move(result));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
new file mode 100644
index 0000000000..82e7f1183c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
+
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
+
+namespace xla {
+
+// MaybeOwningDeviceMemory represents either an owned or unowned device memory.
+// Like std::variant<OwningDeviceMemory, DeviceMemory>. When the object goes
+// output of scope, it will free the underlying memory if it owns it.
+class MaybeOwningDeviceMemory {
+ public:
+ MaybeOwningDeviceMemory() = default;
+ explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned)
+ : mem_(std::move(owned)) {}
+ explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned)
+ : mem_(unowned) {}
+ MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default;
+ ~MaybeOwningDeviceMemory() = default;
+
+ MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) {
+ mem_ = unowned;
+ return *this;
+ }
+
+ MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) {
+ mem_ = std::move(owned);
+ return *this;
+ }
+
+ MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default;
+
+ // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
+ // caller of this function is *not* responsible for freeing the memory.
+ se::DeviceMemoryBase AsDeviceMemoryBase();
+
+ // Release the OwningDeviceMemory without freeing it, and moves the ownership
+ // of the memory buffer from the object to the caller.
+ //
+ // A nullopt is returned if the HasOwnership() == false;
+ absl::optional<OwningDeviceMemory> Release();
+
+ // Returns true if the device_memory has ownership over underlying memory.
+ bool HasOwnership() const;
+
+ private:
+ absl::variant<OwningDeviceMemory, se::DeviceMemoryBase> mem_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 4166ef5baf..b9ec31c497 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() {
void MultiOutputFusion::UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip) {
for (auto instr : instrs_to_update) {
if (skip != nullptr && skip(instr)) {
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 4c8cb7d379..d2c52651c4 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -92,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface {
// Update the reachability map after fusing instr1 and instr2.
void UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip = nullptr);
// Hook for multi-output fusion along producer-consumer edges.
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index ccc06ce613..4869db79e7 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) {
}
XLA_NULLOP_PATTERN(Constant)
XLA_NULLOP_PATTERN(Parameter)
+XLA_NULLOP_PATTERN(Iota)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index ae1e13d8a6..178a78ede0 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -89,7 +89,11 @@ PlatformUtil::GetSupportedPlatforms() {
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
- return platforms[0];
+ se::Platform* platform = platforms[0];
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
}
// Multiple platforms present and we can't pick a reasonable default.
@@ -103,18 +107,27 @@ PlatformUtil::GetSupportedPlatforms() {
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
+
+ se::Platform* platform = nullptr;
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
- return platforms[0];
+ platform = platforms[0];
} else if (platforms.size() == 2) {
for (int i = 0; i < 2; i++) {
if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
- return platforms[1 - i];
+ platform = platforms[1 - i];
+ break;
}
}
}
+ if (platform != nullptr) {
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
+ }
// Multiple platforms present and we can't pick a reasonable default.
string platforms_string = absl::StrJoin(
@@ -132,6 +145,9 @@ PlatformUtil::GetSupportedPlatforms() {
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
return platform;
}
}
@@ -154,7 +170,11 @@ PlatformUtil::GetSupportedPlatforms() {
platform_name);
}
if (matched.size() == 1) {
- return matched[0];
+ auto platform = matched[0];
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
}
string matched_string = absl::StrJoin(
matched, ", ",
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index a395dd5333..fcf269eee9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -34,12 +34,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class ReshapeMoverTest : public HloVerifiedTestBase {
- public:
- ReshapeMoverTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-};
+class ReshapeMoverTest : public HloVerifiedTestBase {};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 2077b57c05..2f4b2667c4 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -26,7 +26,6 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::ArraySlice;
// Transposes the given scatter_indices such that the index_vector_dim becomes
// the most-minor dimension.
@@ -87,7 +86,7 @@ static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
// major dimensions and all the window dimensions appear in the minor
// dimensions.
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
- HloInstruction* updates, ArraySlice<int64> update_window_dims) {
+ HloInstruction* updates, absl::Span<const int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
permutation.reserve(updates_rank);
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index e10c1d9927..f0e2566a3f 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -62,10 +62,9 @@ using absl::StrCat;
using absl::StrFormat;
// Records the arguments used to invoke a computation in an HloSnapshot proto.
-Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::Stream* stream, TransferManager* transfer_manager,
- HloSnapshot* module) {
+Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
+ se::Stream* stream, TransferManager* transfer_manager,
+ HloSnapshot* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
@@ -207,8 +206,8 @@ Status Service::ValidateResultShape(const Shape& client_shape,
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors) {
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors) {
CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
replicated_arguments.resize(options_.number_of_replicas());
@@ -242,7 +241,7 @@ Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
@@ -299,7 +298,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
@@ -367,12 +366,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile) {
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile) {
// Streams where the computation are launched, so we can wait on the streams
// to complete.
std::vector<StreamPool::Ptr> streams;
@@ -511,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult(
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile) {
// Set up streams.
std::vector<StreamPool::Ptr> streams;
@@ -555,8 +551,7 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
// TODO(b/69985541): Support profiling also on this path.
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- replicated_arguments;
+ std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
for (const auto& arg : arguments) {
replicated_arguments.push_back(arg);
}
@@ -595,7 +590,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) {
+ absl::Span<const GlobalDataHandle* const> arguments) {
// Resolve the allocations for the arguments of the computation, and create
// a vector of device memory offsets for the arguments from the allocations.
// In the case of partitioned computations, assume all arguments go on the
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 47d196fb2a..44c5248b15 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -176,7 +176,7 @@ class Service : public ServiceInterface {
// class.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options);
// Picks a parallel response and fills the result.
@@ -191,7 +191,7 @@ class Service : public ServiceInterface {
// Prepare the arguments for executing parallel.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+ absl::Span<const GlobalDataHandle* const> arguments);
protected:
friend class LocalExecutable;
@@ -207,14 +207,14 @@ class Service : public ServiceInterface {
// the corresponding replica.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors);
// Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options);
// Builds an Executable for the given parameters.
@@ -242,21 +242,17 @@ class Service : public ServiceInterface {
// ExecutionProfile object which will be filled in with profile data.
StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile);
// Runs the given executables with the given arguments and register the result
// from each executable in the allocation tracker. The handles of the result
// from the tracker are returned.
StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend,
- tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile);
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile);
// Executes a single computation which has more than one target device.
// The N devices are expected to all return an empty tuple, but one, which
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a04af8b0aa..2611749862 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -45,7 +45,7 @@ using absl::StrFormat;
using absl::StrJoin;
// Returns true if no element is present in slice more than once.
-bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+bool AllUnique(absl::Span<const int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
@@ -57,11 +57,10 @@ Status ExpectArray(const Shape& shape, absl::string_view op_type) {
return Status::OK();
}
-Status VerifyReducerShape(
- const ProgramShape& reducer_shape,
- tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
- int64 inputs) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ absl::Span<const Shape* const> init_value_shapes,
+ absl::Span<const PrimitiveType> input_element_types,
+ int64 inputs) {
if (reducer_shape.parameters_size() != inputs * 2) {
return InvalidArgument(
"Reduction function must take %d parameters, but "
@@ -335,8 +334,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const int64 dimension) {
+ absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
@@ -394,7 +392,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ absl::Span<const Shape* const> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
return InvalidArgument(
@@ -505,13 +503,21 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"The element types of the operands to Pad do not match.");
}
+ if (absl::c_any_of(padding_config.dimensions(),
+ [](const PaddingConfig::PaddingConfigDimension& p) {
+ return p.interior_padding() < 0;
+ })) {
+ return InvalidArgument("Interior padding cannot be negative: %s",
+ padding_config.ShortDebugString());
+ }
+
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
- dimensions[i] = operand_shape.dimensions(i) +
- padding_config.dimensions(i).edge_padding_low() +
- padding_config.dimensions(i).edge_padding_high() +
+ const auto& p = padding_config.dimensions(i);
+ dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
+ p.edge_padding_high() +
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
- padding_config.dimensions(i).interior_padding();
+ p.interior_padding();
}
return ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
@@ -542,22 +548,22 @@ Status ValidateDotDimensionNumbers(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
// Check that dimension numbers are in range.
- auto dims_in_range =
- [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_in_range = [](const int64 rank,
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
in_range) &&
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
};
- tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+ absl::Span<const int64> lhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+ absl::Span<const int64> rhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+ absl::Span<const int64> lhs_batch_dimensions =
AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+ absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
@@ -569,8 +575,8 @@ Status ValidateDotDimensionNumbers(
}
// Check that dimension numbers are unique.
- auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_unique = [](absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
tensorflow::gtl::FlatSet<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
@@ -740,7 +746,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
// the user to provide an explicit broadcast dimension in this case.
@@ -841,7 +847,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
@@ -898,7 +904,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
VLOG(2) << StrFormat(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
@@ -997,8 +1003,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
std::vector<const Shape*> operand_shapes;
operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
@@ -1008,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
@@ -1045,9 +1049,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions) {
if (arg_shapes.empty()) {
return InvalidArgument("Map expects at least one argument.");
}
@@ -1703,7 +1706,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
const Shape& in, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
const int64 fft_rank = fft_length.size();
if (fft_rank < 1 || fft_rank > 3) {
return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
@@ -1784,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
ExpectArray(*operand_shape, "operand of cross replica sum"));
@@ -1827,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
// An Alltoall HLO instruction receives N operands (with the same shape) and
// returns a tuple that contains N array shapes.
TF_RET_CHECK(!operand_shapes.empty());
@@ -1851,8 +1854,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
if (arg_shapes.empty()) {
return InvalidArgument("Reduce must have at least 2 arguments, has 0");
@@ -1864,8 +1867,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
int64 num_reduced_args = arg_shapes.size() / 2;
- tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
- num_reduced_args);
+ auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
// Check that all of the reduced tensors have the same dimensions. The element
// types may be different.
for (int64 i = 1; i < num_reduced_args; ++i) {
@@ -1889,8 +1891,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
- tensorflow::gtl::ArraySlice<const Shape*> init_values(
- arg_shapes, num_reduced_args, arg_shapes.size());
+ auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
std::vector<PrimitiveType> element_types;
for (const Shape* arg : reduced_args) {
element_types.push_back(arg->element_type());
@@ -1992,9 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ const Shape& arg, absl::Span<const int64> starts,
+ absl::Span<const int64> limits, absl::Span<const int64> strides) {
auto error = [&](const string& message) {
return InvalidArgument(
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
@@ -2056,7 +2056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
@@ -2183,7 +2183,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
- const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand_shape, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
@@ -2309,7 +2309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const Shape& operand, absl::Span<const int64> broadcast_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
@@ -2327,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ const Shape& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
@@ -2360,7 +2360,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2465,8 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
// The applied function's arity equals the number of arguments.
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
@@ -2499,8 +2498,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
static Status ValidateGatherDimensionNumbers(
- const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices_shape,
+ const Shape& input_shape, absl::Span<const int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
@@ -2593,7 +2591,7 @@ static Status ValidateGatherDimensionNumbers(
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
@@ -2703,8 +2701,7 @@ static Status ValidateGatherDimensionNumbers(
namespace {
Status ValidateScatterDimensionNumbers(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 235b1a4cf3..a28345acef 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -21,12 +21,12 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -55,7 +55,7 @@ class ShapeInference {
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
const HloInstruction* lhs,
const HloInstruction* rhs);
@@ -73,18 +73,15 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
// Infers the shape produced by applying the given mapping computation shape
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.
@@ -116,14 +113,13 @@ class ShapeInference {
int64 feature_group_count = 1);
// Infers the shape produced by the given FFT type on the given operand.
- static StatusOr<Shape> InferFftShape(
- const Shape& in, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
+ absl::Span<const int64> fft_length);
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers final shape of an Alltoall operation that is created by the xla
// builder.
@@ -134,7 +130,7 @@ class ShapeInference {
// Infers the shape of an HLO all-to-all instruction.
static StatusOr<Shape> InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers the shape of a collective permute operation.
static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
@@ -146,8 +142,8 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply);
// Infers the shape produced by applying the given computation to the operand
@@ -165,24 +161,23 @@ class ShapeInference {
// Infers the shape produced by a reverse operation that reverses the order
// of the elements in the given dimensions.
- static StatusOr<Shape> InferReverseShape(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by a slice operation spanning from the starts to
// the limits in the original shape's dimensions.
//
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
- static StatusOr<Shape> InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides);
+ static StatusOr<Shape> InferSliceShape(const Shape& arg,
+ absl::Span<const int64> starts,
+ absl::Span<const int64> limits,
+ absl::Span<const int64> strides);
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
static StatusOr<Shape> InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Infers the shape produced by a dynamic update slice operation based
// on the shape of operand and update.
@@ -213,30 +208,30 @@ class ShapeInference {
// Infers the shape produced by a broadcast operation.
static StatusOr<Shape> InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ const Shape& operand, absl::Span<const int64> broadcast_sizes);
// Infers the shape produced by a reshape operation from the element type of
// its operand and the new dimension sizes specified.
- static StatusOr<Shape> InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ static StatusOr<Shape> InferReshapeShape(const Shape& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Infers the shape produced by a transpose operation from the element type of
// its operand and its dimensions field.
static StatusOr<Shape> InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+ const Shape& operand, absl::Span<const int64> dimensions);
// Helper that infers the shape produced by performing a concatenate operation
// with the given operand shapes.
static StatusOr<Shape> InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ absl::Span<const Shape* const> arg_shapes, int64 dimension);
// Infers the shape produced by a kAfterAll. Trivially this shape is always a
// TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
// and checking operand shapes. This method verifies that the operand shapes
// are all TOKENs.
static StatusOr<Shape> InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+ absl::Span<const Shape* const> arg_shapes);
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
@@ -266,8 +261,7 @@ class ShapeInference {
// Helper that validates the given arg_shapes are compatible with the shape of
// the to_apply parameters, and returns the to_apply result shape.
static StatusOr<Shape> InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
@@ -281,7 +275,7 @@ class ShapeInference {
static StatusOr<Shape> InferGatherShape(
const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
@@ -299,7 +293,7 @@ class ShapeInference {
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
@@ -327,7 +321,7 @@ class ShapeInference {
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
};
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 4ed8fc6b86..cc92e58ef8 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -17,18 +17,17 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
-using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
// Helper that runs reduce shape inference with the input 'arg' and given
// dimensions to reduce, and checks the inferred shape is as expected. The
// element type here is hard-coded to F32.
- void ExpectInferredReduceShape(
- const Shape& expected_inferred_shape, const Shape& arg,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
+ const Shape& arg,
+ absl::Span<const int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
{&arg, &f32_}, dimensions_to_reduce, to_apply);
@@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
- const tensorflow::gtl::ArraySlice<int64>& bcast) {
+ const absl::Span<const int64>& bcast) {
return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
bcast);
};
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index 905a7e82e6..e1d26da4a2 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <ostream>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -84,6 +84,14 @@ class ShapedBuffer {
*buffers_.mutable_element(index) = buffer;
}
+ // Sets all buffers.
+ //
+ // Precondition: buffers.shape == on_device_shape_
+ void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
+ CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
+ buffers_ = std::move(buffers);
+ }
+
// Returns the underlying ShapeTree containing all the device addresses in the
// ShapedBuffer.
const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f77690a462..21725946b3 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -20,12 +20,12 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -130,7 +130,7 @@ class TransferManager {
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executor) = 0;
+ absl::Span<se::StreamExecutor* const> executor) = 0;
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
@@ -211,8 +211,7 @@ class TransferManager {
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
private:
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index cf00ca102b..6fed7c76d0 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
}
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
points_to_set.AddPointedToBuffer(
logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 62c7bb685d..a9e8a51e09 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 10d382e8ab..a32d1f9026 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the given
// LogicalBuffers.
- void ExpectHasBuffers(
- const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
+ void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
+ absl::Span<const LogicalBuffer* const> buffers) {
std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
}
@@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// top-level buffers of the given instructions.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
PointsToSet::BufferList buffers;
for (auto instruction : instructions) {
buffers.push_back(GetBuffer(instruction, /*index=*/{}));
@@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Overload which takes a set instead of a vector.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferSet& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
ExpectHasTopLevelBuffers(
PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
instructions);
@@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// aliases which are exactly (unordered) the given instruction/index pairs.
void ExpectHasBufferAliases(
const HloInstruction* instruction, const ShapeIndex& index,
- tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>>
- expected) {
+ absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
const LogicalBuffer* buffer =
points_to_analysis_->GetBufferDefinedAt(instruction, index)
.ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc
index 4a530bb0b2..cfb0c787d0 100644
--- a/tensorflow/compiler/xla/service/tuple_util.cc
+++ b/tensorflow/compiler/xla/service/tuple_util.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/tuple_util.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -40,7 +40,7 @@ namespace xla {
/*static*/ HloInstruction* TupleUtil::AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values) {
+ absl::Span<HloInstruction* const> trailing_values) {
CHECK(ShapeUtil::IsTuple(input_tuple->shape()));
HloComputation* computation = input_tuple->parent();
diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h
index e5ff9aaa83..bc5aac09f2 100644
--- a/tensorflow/compiler/xla/service/tuple_util.h
+++ b/tensorflow/compiler/xla/service/tuple_util.h
@@ -38,7 +38,7 @@ class TupleUtil {
// `input_tuple`.
static HloInstruction* AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values);
+ absl::Span<HloInstruction* const> trailing_values);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index 7e4ac92a7c..c3c2603c7e 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -211,8 +211,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() ==
- tensorflow::gtl::ArraySlice<bool>{false}) {
+ if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index f4098f28b3..e8fe33e626 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -110,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
+ case HloOpcode::kIota:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index e14014b961..32e69c335b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,10 +28,6 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
- WhileLoopInvariantCodeMotionTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index cfe4104f6d..1c892ba179 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -28,11 +28,6 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
- public:
- WhileLoopSimplifierTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false) {}
-
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index e8f76ff745..f90ac91f9d 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -94,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
CHECK(ShapeUtil::IsTuple(while_instr->shape()));
int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h
index e67636d80f..b1c4486887 100644
--- a/tensorflow/compiler/xla/service/while_util.h
+++ b/tensorflow/compiler/xla/service/while_util.h
@@ -55,7 +55,7 @@ class WhileUtil {
// that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
+ absl::Span<HloInstruction* const> instructions);
using LoopStateTy = std::vector<HloInstruction*>;
using LoopBodyGeneratorTy = std::function<StatusOr<LoopStateTy>(
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index c793a39c27..52c895e8d4 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -23,13 +23,13 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -262,6 +262,25 @@ class ShapeTree {
template <typename Fn>
Status ForEachMutableElementWithStatus(const Fn& func);
+ // Maps each element to generate a new tree with the same shape.
+ template <typename U>
+ ShapeTree<U> Map(const std::function<U(const T&)>& func) {
+ ShapeTree<U> result(shape_storage_);
+ ForEachElement([&](const ShapeIndex& index, const T& t) {
+ *result.mutable_element(index) = func(t);
+ });
+ return result;
+ }
+
+ template <typename U>
+ ShapeTree<U> Map(const std::function<U(T*)>& func) {
+ ShapeTree<U> result(shape_storage_);
+ ForEachMutableElement([&](const ShapeIndex& index, T* t) {
+ *result.mutable_element(index) = func(t);
+ });
+ return result;
+ }
+
// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
// 'target_base_index'.
@@ -463,9 +482,6 @@ template <typename T>
ShapeTree<T>::ShapeTree(Shape shape)
: shape_storage_(std::make_shared<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
- // The shape_ field is just used to hold the structure of the shape.
- // It should not be relied upon to store layout information.
- LayoutUtil::ClearLayout(shape_storage_.get());
const int64 count = CountSubshapes(*shape_);
nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
@@ -502,9 +518,6 @@ template <typename T>
ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
: shape_storage_(std::make_shared<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
- // The shape_ field is just used to hold the structure of the shape.
- // It should not be relied upon to store layout information.
- LayoutUtil::ClearLayout(shape_storage_.get());
const int64 count = CountSubshapes(*shape_);
nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 5477a78a9a..9772c06bce 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -95,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
}
if (ShapeUtil::IsTuple(lhs)) {
- return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- [=](const Shape& l, const Shape& r) {
- return CompareShapes(l, r, compare_layouts,
- ignore_fp_precision);
- });
+ return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ [=](const Shape& l, const Shape& r) {
+ return CompareShapes(l, r, compare_layouts,
+ ignore_fp_precision);
+ });
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array tupes such as opaque and token types are trivially
// the same.
@@ -111,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
return false;
}
if (LayoutUtil::IsDenseArray(lhs)) {
- if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
- LayoutUtil::MinorToMajor(rhs))) {
+ if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs),
+ LayoutUtil::MinorToMajor(rhs))) {
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
return false;
}
- if (!ContainersEqual(lhs.layout().padded_dimensions(),
- rhs.layout().padded_dimensions())) {
+ if (!absl::c_equal(lhs.layout().padded_dimensions(),
+ rhs.layout().padded_dimensions())) {
VLOG(3)
<< "CompareShapes: lhs padded_dimensions != rhs padded_dimensions";
return false;
@@ -139,8 +139,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
if (dimensions.size() != minor_to_major.size()) {
return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
dimensions.size(), minor_to_major.size());
@@ -214,8 +214,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return program_shape;
}
-/* static */ Shape ShapeUtil::MakeShape(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
+ absl::Span<const int64> dimensions) {
CHECK(IsArrayPrimitiveType(element_type));
Shape result;
PopulateShape(element_type, dimensions, &result);
@@ -223,21 +223,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions) {
std::vector<int64> layout(dimensions.size());
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
@@ -256,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
}
-/* static */ void ShapeUtil::PopulateShape(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- Shape* shape) {
+/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ Shape* shape) {
shape->Clear();
shape->set_element_type(element_type);
for (int64 dimension : dimensions) {
@@ -268,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
TF_DCHECK_OK(ValidateShape(*shape));
}
-/* static */ Shape ShapeUtil::MakeTupleShape(
- tensorflow::gtl::ArraySlice<Shape> shapes) {
+/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
result.mutable_tuple_shapes()->Reserve(shapes.size());
@@ -662,7 +661,7 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
const Shape& rhs) {
CHECK(ShapeUtil::IsArray(lhs));
CHECK(ShapeUtil::IsArray(rhs));
- return ContainersEqual(lhs.dimensions(), rhs.dimensions());
+ return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
@@ -676,8 +675,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return IsArray(rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringElementType);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringElementType);
} else {
// Opaque, token, etc types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -691,8 +690,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
CompatibleIgnoringElementType(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringFpPrecision);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringFpPrecision);
} else {
// Opaque, token, etc types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -791,7 +790,7 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
- tensorflow::gtl::ArraySlice<int64> padded_dimensions =
+ absl::Span<const int64> padded_dimensions =
LayoutUtil::PaddedDimensions(shape);
if (!padded_dimensions.empty()) {
CHECK_EQ(Rank(shape), padded_dimensions.size());
@@ -1035,7 +1034,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
CHECK(ShapeUtil::IsArray(shape));
- return ArrayContains<int64>(AsInt64Slice(shape.dimensions()), 1);
+ return absl::c_linear_search(shape.dimensions(), 1);
}
namespace {
@@ -1115,7 +1114,7 @@ Status ForEachMutableSubshapeHelper(
}
/* static */ Shape ShapeUtil::PermuteDimensions(
- tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) {
+ absl::Span<const int64> permutation, const Shape& shape) {
Shape new_shape = shape;
new_shape.clear_dimensions();
for (auto dim : Permute(permutation, shape.dimensions())) {
@@ -1259,7 +1258,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::TransposeIsBitcast(
const Shape& input_shape, const Shape& output_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping) {
+ absl::Span<const int64> dimension_mapping) {
CHECK(LayoutUtil::HasLayout(input_shape) &&
LayoutUtil::HasLayout(output_shape));
@@ -1286,7 +1285,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
// apply(input_dimensions, I) =
// apply((dimension_mapping * output_dimensions), I)
// input_dimensions = dimension_mapping * output_dimensions
- return ContainersEqual(
+ return absl::c_equal(
ComposePermutations(dimension_mapping,
AsInt64Slice(output_shape.layout().minor_to_major())),
input_shape.layout().minor_to_major());
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 83e58545bf..8234fcdd3f 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
@@ -147,7 +147,7 @@ class ShapeIndexView {
string ToString() const;
private:
- tensorflow::gtl::ArraySlice<int64> indices_;
+ absl::Span<const int64> indices_;
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
@@ -328,7 +328,7 @@ class ShapeUtil {
static Shape ChangeElementType(const Shape& original, PrimitiveType type);
// Creates a tuple shape from a slice of element shapes within the tuple.
- static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);
+ static Shape MakeTupleShape(absl::Span<const Shape> shapes);
// Creates an opaque shape. These are generally used for threading a context
// into a custom operation.
@@ -355,31 +355,29 @@ class ShapeUtil {
// Constructs a new shape with the given element type and sequence of
// dimensions.
static Shape MakeShape(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a Shape with element type corresponding to T and the given
// dimensions
template <typename T>
- static Shape MakeShapeWithType(
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
dimensions);
}
// Constructs a new shape with the given minor_to_major order in its Layout.
// Returns a value shape such that shape.has_layout().
- static Shape MakeShapeWithLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ static Shape MakeShapeWithLayout(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major);
- static Shape MakeShapeWithSparseLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- int64 max_sparse_elements);
+ static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ int64 max_sparse_elements);
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
- PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ PrimitiveType element_type, absl::Span<const int64> dimensions);
// Returns a new Shape based on the given Shape with low-dimension-major
// layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
@@ -391,8 +389,7 @@ class ShapeUtil {
// As MakeShape, but the object to write to is passed in.
static void PopulateShape(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- Shape* shape);
+ absl::Span<const int64> dimensions, Shape* shape);
// Validates that the provided shape satisfies invariants.
static Status ValidateShape(const Shape& shape);
@@ -539,7 +536,7 @@ class ShapeUtil {
// !HasLayout(shape) ||
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
// InversePermutation(permutation)).
- static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
+ static Shape PermuteDimensions(absl::Span<const int64> permutation,
const Shape& shape);
// If we can go from `shape_pre` to `shape_post` by merely inserting or
@@ -580,9 +577,9 @@ class ShapeUtil {
// to its input and thus may be replaced with a bitcast.
//
// Precondition: Both input_shape and output_shape have explicit layouts.
- static bool TransposeIsBitcast(
- const Shape& input_shape, const Shape& output_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping);
+ static bool TransposeIsBitcast(const Shape& input_shape,
+ const Shape& output_shape,
+ absl::Span<const int64> dimension_mapping);
// Returns whether a reshape from "input_shape" to "output_shape" is a
// bitcast.
@@ -621,12 +618,12 @@ class ShapeUtil {
// continue, or false otherwise.
//
// visitor_function must be a callable of type
- // StatusOr<bool>(ArraySlice<int64>) or compatible.
+ // StatusOr<bool>(Span<int64>) or compatible.
template <typename FnType>
static Status ForEachIndexWithStatus(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
return ForEachIndexInternal(shape, base, count, incr, visitor_function);
}
@@ -648,13 +645,12 @@ class ShapeUtil {
}
template <typename FnType>
- static void ForEachIndex(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
ForEachIndexWithStatus(shape, base, count, incr,
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ [&](absl::Span<const int64> indices) {
return StatusOr<bool>(visitor_function(indices));
})
.IgnoreError();
@@ -676,7 +672,7 @@ class ShapeUtil {
template <typename FnType>
static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
ForEachIndexWithStatus(shape,
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ [&](absl::Span<const int64> indices) {
return StatusOr<bool>(visitor_function(indices));
})
.IgnoreError();
@@ -687,18 +683,18 @@ class ShapeUtil {
// matter.
//
// visitor_function must be a callable of type
- // void(ArraySlice<int64>) or compatible.
+ // void(Span<int64>) or compatible.
template <typename FnType>
static void ForEachIndexParallel(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
// The parallel version of ForEachIndexInternal can never fail.
CHECK(ForEachIndexInternal(
shape, base, count, incr,
- [&visitor_function](tensorflow::gtl::ArraySlice<int64> indexes)
- -> StatusOr<bool> {
+ [&visitor_function](
+ absl::Span<const int64> indexes) -> StatusOr<bool> {
visitor_function(indexes);
return true;
},
@@ -720,9 +716,9 @@ class ShapeUtil {
template <typename FnType>
static Status ForEachIndexInternal(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function,
bool parallel = false) {
if (ShapeUtil::IsZeroElementArray(shape)) {
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 7549ba9c78..6ca4085aaf 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -705,11 +705,10 @@ TEST(ShapeUtilTest, ForEachIndex) {
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
// Increments at every invocation.
int invocations = 0;
- auto increment_func =
- [&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
- invocations++;
- return true;
- };
+ auto increment_func = [&invocations](absl::Span<const int64> indexes) {
+ invocations++;
+ return true;
+ };
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
@@ -726,8 +725,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
// Increments at every invocation.
int invocations = 0;
auto increment_func =
- [&invocations](
- tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
+ [&invocations](absl::Span<const int64> indexes) -> StatusOr<bool> {
if (++invocations == 5) {
return Unimplemented("Cannot increment beyond 5.");
}
@@ -748,7 +746,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) {
Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
int64 output[10][10];
int init = 5;
- auto set_func = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto set_func = [&](absl::Span<const int64> indexes) {
output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1];
};
diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc
index 31844abd89..1c135dda86 100644
--- a/tensorflow/compiler/xla/sparse_index_array.cc
+++ b/tensorflow/compiler/xla/sparse_index_array.cc
@@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
- tensorflow::gtl::ArraySlice<int64> indices)
+ absl::Span<const int64> indices)
: SparseIndexArray(max_indices, rank,
std::vector<int64>(indices.begin(), indices.end())) {}
@@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const {
return indices_.size() / rank_;
}
-tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At(
+absl::Span<const int64> SparseIndexArray::At(
int64 sparse_element_number) const {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
- return tensorflow::gtl::ArraySlice<int64>(
+ return absl::Span<const int64>(
indices_.data() + rank_ * sparse_element_number, rank_);
}
-tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At(
- int64 sparse_element_number) {
+absl::Span<int64> SparseIndexArray::At(int64 sparse_element_number) {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
- return tensorflow::gtl::MutableArraySlice<int64>(
- indices_.data() + rank_ * sparse_element_number, rank_);
+ return absl::Span<int64>(indices_.data() + rank_ * sparse_element_number,
+ rank_);
}
-void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) {
+void SparseIndexArray::Append(absl::Span<const int64> index) {
CHECK_GT(rank_, 0);
CHECK_EQ(index.size(), rank_);
indices_.insert(indices_.end(), index.begin(), index.end());
@@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const {
if (num_indices < 2) {
return true;
}
- tensorflow::gtl::ArraySlice<int64> last = At(0);
+ absl::Span<const int64> last = At(0);
if (!IndexUtil::IndexInBounds(shape, last)) {
return false;
}
for (int64 n = 1; n < num_indices; ++n) {
- tensorflow::gtl::ArraySlice<int64> next = At(n);
+ absl::Span<const int64> next = At(n);
if (!IndexUtil::IndexInBounds(shape, next)) {
return false;
}
diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h
index 70fab3bea5..a96d483462 100644
--- a/tensorflow/compiler/xla/sparse_index_array.h
+++ b/tensorflow/compiler/xla/sparse_index_array.h
@@ -21,10 +21,10 @@ limitations under the License.
#include <vector>
#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -65,7 +65,7 @@ class SparseIndexArray {
SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices = {});
SparseIndexArray(int64 max_indices, int64 rank,
- tensorflow::gtl::ArraySlice<int64> indices);
+ absl::Span<const int64> indices);
// Returns the number of elements represented by the indices stored in the
// array.
@@ -73,12 +73,12 @@ class SparseIndexArray {
// Returns a slice that refers to the given sparse index number. The argument
// must be in the range [0, element_count()).
- tensorflow::gtl::ArraySlice<int64> At(int64 sparse_element_number) const;
- tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_element_number);
+ absl::Span<const int64> At(int64 sparse_element_number) const;
+ absl::Span<int64> At(int64 sparse_element_number);
// Adds the given index at the end of the array. The new size of the
// SparseIndexArray must not exceed `max_indices`.
- void Append(tensorflow::gtl::ArraySlice<int64> index);
+ void Append(absl::Span<const int64> index);
// Removes all indices from the array.
void Clear();
@@ -96,8 +96,8 @@ class SparseIndexArray {
int64 max_indices() const { return max_indices_; }
// Returns a pointer to the int64 array that holds the sparse indices.
- tensorflow::gtl::MutableArraySlice<int64> mutable_data() { return &indices_; }
- tensorflow::gtl::ArraySlice<int64> data() const { return indices_; }
+ absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
+ absl::Span<const int64> data() const { return indices_; }
// Sorts this sparse index array along with the set of corresponding values.
// The indices and values are sorted in the lexicographic order of the
@@ -115,7 +115,7 @@ class SparseIndexArray {
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
//
template <typename NativeT>
- void SortWithValues(tensorflow::gtl::MutableArraySlice<NativeT> values);
+ void SortWithValues(absl::Span<NativeT> values);
private:
std::vector<int64> indices_;
@@ -124,8 +124,7 @@ class SparseIndexArray {
};
template <typename NativeT>
-void SparseIndexArray::SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT> values) {
+void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
int64 num_elements = index_count();
CHECK_EQ(values.size(), num_elements);
std::vector<int64> sort_order;
diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc
index 7377f88958..e54057c400 100644
--- a/tensorflow/compiler/xla/sparse_index_array_test.cc
+++ b/tensorflow/compiler/xla/sparse_index_array_test.cc
@@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) {
std::vector<double> values = {
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
};
- a.SortWithValues<double>(&values);
+ a.SortWithValues<double>(absl::MakeSpan(values));
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
6, 7, 6, 7, 8}));
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index a0829b0d02..36b8fb2644 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -69,7 +69,6 @@ cc_library(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
@@ -77,6 +76,8 @@ cc_library(
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_headers_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -101,6 +102,7 @@ cc_library(
"//tensorflow/core:test",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -132,6 +134,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -208,6 +211,7 @@ cc_library(
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -282,6 +286,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -562,6 +567,7 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -578,8 +584,7 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -602,8 +607,8 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -625,12 +630,11 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1016,6 +1020,7 @@ xla_test(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1144,6 +1149,7 @@ xla_test(
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1174,6 +1180,7 @@ xla_test_library(
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1457,11 +1464,11 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1475,14 +1482,12 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
@@ -1492,7 +1497,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1648,8 +1653,8 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1662,13 +1667,13 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1828,6 +1833,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1841,15 +1847,11 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -1860,6 +1862,7 @@ xla_test(
"//tensorflow/core:test",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1867,10 +1870,8 @@ xla_test(
name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"],
deps = [
- "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:hlo",
@@ -1884,6 +1885,7 @@ xla_test(
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2012,16 +2014,15 @@ xla_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2134,6 +2135,8 @@ xla_test(
shard_count = 30,
tags = [
"enable_for_xla_interpreter",
+ # Require optimized builds, iota_test_cpu is very slow in fastbuild.
+ "optonly",
],
deps = [
":client_library_test_base",
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 577fd1ab3b..0bf4556b43 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -35,13 +36,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
class ArrayElementwiseOpTest : public ClientLibraryTestBase {
public:
@@ -433,8 +432,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
class IntegerDivideOpTest : public ArrayElementwiseOpTest {
protected:
template <typename T>
- void TestDivRem(ArraySlice<T> dividends, ArraySlice<T> divisors,
- ArraySlice<T> quotients, ArraySlice<T> remainders) {
+ void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
+ absl::Span<const T> quotients,
+ absl::Span<const T> remainders) {
{
XlaBuilder builder(TestName());
XlaOp dividend;
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 6c20f654fe..65589b0d6a 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -65,7 +65,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) {
Log(x);
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {},
- error_spec_);
+ ErrorSpec(0.01, 0.01));
}
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
@@ -110,7 +110,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 1d28e85b16..fe4267c73b 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -53,10 +53,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
}
}
- std::unique_ptr<GlobalData> MakeR3Data(
- tensorflow::gtl::ArraySlice<int64> bounds,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape,
- Array3D<float>* r3_array, float start, float end, int seed) {
+ std::unique_ptr<GlobalData> MakeR3Data(absl::Span<const int64> bounds,
+ absl::Span<const int64> minor_to_major,
+ Shape* r3_shape,
+ Array3D<float>* r3_array, float start,
+ float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
@@ -66,10 +67,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
return r3_global_data;
}
- std::unique_ptr<GlobalData> MakeR2Data(
- tensorflow::gtl::ArraySlice<int64> bounds,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape,
- Array2D<float>* r2_array, float start, float end, int seed) {
+ std::unique_ptr<GlobalData> MakeR2Data(absl::Span<const int64> bounds,
+ absl::Span<const int64> minor_to_major,
+ Shape* r2_shape,
+ Array2D<float>* r2_array, float start,
+ float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
@@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
spec.output_bounds[2]);
- auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+ auto Each = ([&](absl::Span<const int64> indices, float* value) {
float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
indices[1] % spec.input_bounds[1],
indices[2] % spec.input_bounds[2]);
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 9cd974fd9b..8a236db0ff 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const {
}
StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return client_->Execute(computation, arguments, &execution_options_);
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
@@ -115,7 +114,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
@@ -124,8 +123,7 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
StatusOr<std::unique_ptr<Literal>>
ClientLibraryTestBase::ExecuteAndTransferReference(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
@@ -138,7 +136,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference(
}
string ClientLibraryTestBase::ExecuteToString(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
auto computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status().ToString();
@@ -156,7 +154,7 @@ string ClientLibraryTestBase::ExecuteToString(
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
@@ -164,15 +162,14 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
void ClientLibraryTestBase::ComputeAndCompareLiteral(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout) {
+ absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
shape_with_layout));
}
void ClientLibraryTestBase::ComputeAndCompareLiteral(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
error, shape_with_layout));
@@ -180,7 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output) {
// Try with no layout requirement.
@@ -205,7 +202,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& /*expected*/,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout) {
@@ -252,10 +249,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
// Every argument has an assigned layout.
TF_ASSIGN_OR_RETURN(
auto actual,
- ExecuteAndTransfer(
- computation,
- tensorflow::gtl::ArraySlice<GlobalData*>(arguments_with_layout),
- output_with_layout));
+ ExecuteAndTransfer(computation,
+ absl::Span<GlobalData* const>(arguments_with_layout),
+ output_with_layout));
string error_message = "Test with input layouts: ";
for (const auto& str : layout_strings) {
absl::StrAppend(&error_message, str, " ");
@@ -269,7 +265,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
+ absl::Span<GlobalData* const> arguments_passed_in,
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
@@ -290,10 +286,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
if (ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
- } else {
- TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
- expected.shape().element_type() == PRED)
- << ShapeUtil::HumanString(expected.shape());
}
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
@@ -333,8 +325,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
- ErrorSpec error, const Shape* shape_with_layout) {
+ absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
+ const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
@@ -350,8 +342,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
}
- TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
- ShapeUtil::ElementIsComplex(expected.shape()));
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
@@ -392,7 +382,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
void ClientLibraryTestBase::ComputeAndCompareR1U8(
XlaBuilder* builder, absl::string_view expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
@@ -411,7 +401,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
@@ -423,7 +413,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
@@ -434,7 +424,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
}
void ClientLibraryTestBase::ComputeAndCompare(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments) {
auto status_or_data = ComputeValueAndReference(builder, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
@@ -446,8 +436,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
}
void ClientLibraryTestBase::ComputeAndCompare(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
auto status_or_data = ComputeValueAndReference(builder, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
@@ -460,7 +449,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ClientLibraryTestBase::ComputeValueAndReference(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
// into a vector to keep the data alive on the service until the end of this
// function.
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index ac96d3e325..22dfdfb0e4 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -49,8 +49,8 @@ namespace xla {
// use_bfloat16_params with that value. Returns the result.
template <typename TestCase>
std::vector<TestCase> ExpandUseBfloat16(
- tensorflow::gtl::ArraySlice<bool> use_bfloat16_params,
- tensorflow::gtl::ArraySlice<TestCase> specs) {
+ absl::Span<const bool> use_bfloat16_params,
+ absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (bool use_bfloat16 : use_bfloat16_params) {
for (const auto& spec : specs) {
@@ -93,15 +93,15 @@ class ClientLibraryTestBase : public ::testing::Test {
// execution options. Modify execution_options_ in your test if you want to
// customize the options.
StatusOr<std::unique_ptr<GlobalData>> Execute(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
// This executes the computation via the reference client (which connects a
@@ -109,13 +109,13 @@ class ClientLibraryTestBase : public ::testing::Test {
// computation.
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
// Run a computation and return its value as a string. If an error
// occurs, then instead return the error as a string.
string ExecuteToString(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
// Convenience methods for building and running a computation, transferring
// the result, and comparing it to the expected value(s). Methods are
@@ -125,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test {
// for integral types without the ErrorSpec parameter.
template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// As above, but uses a bitmap to hold the predicate vector to avoid
// deficiencies of vector<bool>.
void ComputeAndCompareR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// Build and run the computation and compare the result with the given
// literal. shape_with_layout indicates the result layout to request when
// calling Execute.
- void ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout = nullptr);
- void ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
- const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
// ComputeAndCompare variant which returns an error status.
Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_layout = nullptr);
Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
// Compare the result of the computation to a strings. In XLA strings are
// represented using rank-1 U8 shapes.
- void ComputeAndCompareR1U8(
- XlaBuilder* builder, absl::string_view expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
+ absl::Span<GlobalData* const> arguments);
// Convenience method for running a built computation, transferring the
// result, and comparing it to the expected tuple literal.
- void ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- void ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+ void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments);
+ void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ ErrorSpec error);
// Convenience method for running a built computation and comparing the result
// with the reference result.
void ComputeAndCompare(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
+ absl::Span<const Literal> arguments);
void ComputeAndCompare(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error);
+ absl::Span<const Literal> arguments, ErrorSpec error);
// Create scalar operations for use in reductions.
XlaComputation CreateScalarRelu();
@@ -337,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR1Parameter(
- tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle);
// Creates a parameter instruction that wraps the given constant array
@@ -381,7 +377,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ComputeValueAndReference(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
+ absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@@ -390,12 +386,12 @@ class ClientLibraryTestBase : public ::testing::Test {
private:
Status ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output);
Status ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
@@ -415,7 +411,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -425,7 +421,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
@@ -440,8 +436,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -450,8 +446,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ XlaBuilder* builder, absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
@@ -467,7 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -477,7 +473,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
@@ -493,7 +489,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -503,7 +499,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
@@ -519,7 +515,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
std::unique_ptr<Literal> expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@@ -529,7 +525,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
@@ -558,7 +554,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
- tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 7c52c9fbbb..03d5696499 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -38,10 +38,9 @@ namespace {
class CompilationCacheTest : public ClientLibraryTestBase {
public:
- void ExecuteComputationR0F32(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
- bool expect_cache_hit) {
+ void ExecuteComputationR0F32(const XlaComputation& computation,
+ absl::Span<GlobalData* const> arguments,
+ float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
std::unique_ptr<Literal> result =
client_
@@ -56,7 +55,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
void ExecuteComputationR2F32(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
std::initializer_list<std::initializer_list<float>> expected_result,
bool expect_cache_hit) {
ExecutionProfile execution_profile;
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 50a9ebc1e9..526626c1dd 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -54,7 +54,7 @@ class CopyOpTest : public HloTestBase {
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ absl::Span<const int64> permutation);
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
@@ -187,9 +187,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
}
-void CopyOpTest::TestCopyConstantLayoutR4(
- size_t n1, size_t n2, size_t n3, size_t n4,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
+ size_t n4,
+ absl::Span<const int64> permutation) {
Array4D<int32> a(n1, n2, n3, n4);
for (size_t i = 0; i < n1; ++i) {
for (size_t j = 0; j < n2; ++j) {
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
index 5f234f36a8..86fd1ceb13 100644
--- a/tensorflow/compiler/xla/tests/deallocation_test.cc
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
@@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase {
// Build and execute the given computation then verify the results can be
// transferred from the device successfully.
std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
XlaComputation computation = builder->Build().ConsumeValueOrDie();
auto global_data =
client_->Execute(computation, arguments, &execution_options_)
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index 2db6503afa..eb15fc0593 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase {
// Build and execute the given computation then verify the results can be
// transferred from the device successfully.
std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
XlaComputation computation = builder->Build().ConsumeValueOrDie();
auto global_data =
client_->Execute(computation, arguments, &execution_options_)
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 7f6f203a1b..9bf3767ca3 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -114,14 +114,14 @@ class DynamicSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
- void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
+ void RunR1(absl::Span<const int> input_values_int,
const std::vector<IndexT> slice_starts,
const std::vector<int64>& slice_sizes,
- tensorflow::gtl::ArraySlice<int> expected_values_int) {
+ absl::Span<const int> expected_values_int) {
// bfloat16 has explicit constructors, so it does not implicitly convert the
// way built-in types do, which is why we can't take the parameter as an
- // ArraySlice<DataT>. We also can't convert it to a vector, because
- // vector<bool> is special so that it cannot be an ArraySlice<bool>, which
+ // Span<DataT>. We also can't convert it to a vector, because
+ // vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
@@ -385,10 +385,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
- void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
- tensorflow::gtl::ArraySlice<int> update_values_int,
+ void RunR1(absl::Span<const int> input_values_int,
+ absl::Span<const int> update_values_int,
const std::vector<IndexT> slice_starts,
- tensorflow::gtl::ArraySlice<int> expected_values_int) {
+ absl::Span<const int> expected_values_int) {
Literal input_values =
std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
index 4a835a8e21..3be9657db4 100644
--- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include <string>
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -37,8 +37,8 @@ class FloorCeilTest : public ClientLibraryTestBase {
};
// Runs a computation and comparison on expected vs f(input)
- void TestR1F32(tensorflow::gtl::ArraySlice<float> input,
- tensorflow::gtl::ArraySlice<float> expected, Function f) {
+ void TestR1F32(absl::Span<const float> input,
+ absl::Span<const float> expected, Function f) {
LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}";
XlaBuilder builder(TestName());
auto c = ConstantR1<float>(&builder, input);
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 341124170a..7cb2f0cedf 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -42,14 +43,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::gtl::ArraySlice;
-
namespace xla {
namespace {
@@ -113,7 +111,7 @@ class FusionTest : public HloTestBase {
hlos[0] = builder.AddInstruction(std::move(root_hlo));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(
- ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
+ absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
HloInstruction::FusionKind::kLoop);
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
@@ -127,12 +125,12 @@ class FusionTest : public HloTestBase {
private:
template <typename T>
- T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
+ T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span<const float> xs);
};
template <>
float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
- ArraySlice<float> xs) {
+ absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kAdd:
return xs[0] + xs[1];
@@ -157,7 +155,7 @@ float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
template <>
bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
- ArraySlice<float> xs) {
+ absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kEq:
return xs[0] == xs[1];
@@ -601,7 +599,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
- HloInstruction::FusionKind::kLoop);
+ HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 205d417f0c..6d63498044 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -34,8 +34,7 @@ class GatherOperationTest : public HloTestBase {
RunTest(hlo_text, {operand, start_indices});
}
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 51450314b6..1115e50fe3 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
::testing::Values(UnaryPredTestParam{
[](half x) { return isfinite(x); }, &IsFinite}));
-using BinaryBuildFuncTy =
- std::function<void(const xla::XlaOp& x, const xla::XlaOp& y,
- tensorflow::gtl::ArraySlice<int64>)>;
+using BinaryBuildFuncTy = std::function<void(
+ const xla::XlaOp& x, const xla::XlaOp& y, absl::Span<const int64>)>;
struct BinaryOpTestParam {
std::function<half(half, half)> compute_func;
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 93ea144438..fc4c68246e 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -44,7 +44,6 @@ namespace {
using absl::optional;
using absl::string_view;
-using tensorflow::gtl::ArraySlice;
constexpr char kInterpreter[] = "interpreter";
@@ -130,14 +129,12 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() {
}
StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments);
}
std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
return test_runner_
.Execute(std::move(module), arguments,
/*run_hlo_passes=*/false)
@@ -145,8 +142,7 @@ std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
}
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}
@@ -169,7 +165,8 @@ StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
}
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor) {
TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
@@ -188,7 +185,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompare(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
@@ -201,7 +199,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 06bcc39741..4c88257bb2 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
@@ -114,18 +114,15 @@ class HloTestBase : public ::testing::Test {
// Executes the given module and return the result as a Literal.
StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
// Same as above, except the module will be executed without running any HLO
// passes on it.
std::unique_ptr<Literal> ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
std::unique_ptr<Literal> ExecuteAndTransfer(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
// Executes the given hlo module on two backends and compares results.
//
@@ -140,7 +137,7 @@ class HloTestBase : public ::testing::Test {
// modified.
::testing::AssertionResult RunAndCompare(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
+ const absl::Span<Literal* const> arguments,
const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -149,7 +146,7 @@ class HloTestBase : public ::testing::Test {
// optimization.
::testing::AssertionResult RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
+ const absl::Span<Literal* const> arguments,
const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -261,7 +258,7 @@ class HloTestBase : public ::testing::Test {
// error happens before the results are computed, returns the error status.
StatusOr<::testing::AssertionResult> RunAndCompareInternal(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
+ const absl::Span<Literal* const> arguments,
const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor);
};
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index cc6967feed..8fbc4fa753 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -29,8 +29,8 @@ namespace xla {
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision);
+ explicit HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
~HloVerifiedTestBase() override;
// Constructs a default shape verifier.
diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc
index 07c3c6b878..310f349592 100644
--- a/tensorflow/compiler/xla/tests/iota_test.cc
+++ b/tensorflow/compiler/xla/tests/iota_test.cc
@@ -39,7 +39,7 @@ TEST_P(IotaR1Test, DoIt) {
const auto element_type = std::get<0>(spec);
const int64 num_elements = std::get<1>(spec);
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
- IotaGen(&builder, element_type, num_elements);
+ Iota(&builder, element_type, num_elements);
if (element_type == F32) {
ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {},
ErrorSpec{0.0001});
@@ -71,7 +71,7 @@ TEST_P(IotaR2Test, DoIt) {
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
std::vector<int64> dimensions = {42};
dimensions.insert(dimensions.begin() + iota_dim, num_elements);
- IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
if (primitive_util::IsFloatingPointType(element_type)) {
ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
} else {
@@ -98,7 +98,7 @@ TEST_P(IotaR3Test, DoIt) {
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
std::vector<int64> dimensions = {42, 19};
dimensions.insert(dimensions.begin() + iota_dim, num_elements);
- IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
if (primitive_util::IsFloatingPointType(element_type)) {
ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
} else {
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 3dad91951e..96f72212f3 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -62,7 +62,7 @@ class LiteralTestUtil {
static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
template <typename NativeT>
- static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
+ static void ExpectR1Equal(absl::Span<const NativeT> expected,
const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR2Equal(
@@ -102,7 +102,7 @@ class LiteralTestUtil {
const ErrorSpec& error);
template <typename NativeT>
- static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
+ static void ExpectR1Near(absl::Span<const NativeT> expected,
const LiteralSlice& actual, const ErrorSpec& error);
template <typename NativeT>
@@ -160,7 +160,7 @@ template <typename NativeT>
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
- tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
+ absl::Span<const NativeT> expected, const LiteralSlice& actual) {
EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
}
@@ -206,7 +206,7 @@ template <typename NativeT>
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
- tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
+ absl::Span<const NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 948b60061e..a8c68fc7fd 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -164,7 +164,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
return ExecuteLocally(computation, arguments, build_options, run_options)
@@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions());
}
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
std::vector<const Shape*> argument_layouts(arguments.size());
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index b4477e9a6b..90095c5d41 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -93,19 +93,19 @@ class LocalClientTestBase : public ::testing::Test {
// options.
StatusOr<ScopedShapedBuffer> ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
StatusOr<ScopedShapedBuffer> ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
ScopedShapedBuffer ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
ScopedShapedBuffer ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 16b77e965d..05f90ba9fb 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -47,7 +47,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::gtl::ArraySlice;
class MultiOutputFusionTest : public HloTestBase {
protected:
@@ -96,8 +95,8 @@ class MultiOutputFusionTest : public HloTestBase {
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub, add2}, 0, 2)));
+ auto tuple =
+ computation->AddInstruction(HloInstruction::CreateTuple({sub, add2}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
auto gte1 = computation->AddInstruction(
@@ -159,8 +158,8 @@ class MultiOutputFusionTest : public HloTestBase {
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub_U8, add}, 0, 2)));
+ auto tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple({sub_U8, add}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0));
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 2fc7f816b5..58539e6b06 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase {
protected:
void TestCompare(bool lhs, bool rhs, bool expected,
std::function<XlaOp(const xla::XlaOp&, const xla::XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
+ absl::Span<const int64>)>
op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<bool>(&builder, lhs);
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 326e13b386..5f322b768d 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -37,8 +37,7 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- std::unique_ptr<Literal> UniformTest(T a, T b,
- tensorflow::gtl::ArraySlice<int64> dims,
+ std::unique_ptr<Literal> UniformTest(T a, T b, absl::Span<const int64> dims,
int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
@@ -50,8 +49,9 @@ class PrngTest : public ClientLibraryTestBase {
};
template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(
- T a, T b, tensorflow::gtl::ArraySlice<int64> dims, int64 seed) {
+std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
+ absl::Span<const int64> dims,
+ int64 seed) {
XlaBuilder builder(TestName());
RngUniform(
ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -61,7 +61,7 @@ std::unique_ptr<Literal> PrngTest::UniformTest(
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
- actual->EachCell<T>([=](tensorflow::gtl::ArraySlice<int64>, T value) {
+ actual->EachCell<T>([=](absl::Span<const int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
@@ -117,7 +117,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
for (int64 seed = 0; seed < count; ++seed) {
auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
result->Literal::EachCell<bfloat16>(
- [&](tensorflow::gtl::ArraySlice<int64>, bfloat16 value) {
+ [&](absl::Span<const int64>, bfloat16 value) {
int64 index = static_cast<int64>((value - low) / interval);
counts[index]++;
});
@@ -149,8 +149,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
- actual->EachCell<int32>([&counts](tensorflow::gtl::ArraySlice<int64>,
- int32 value) { ++counts[value]; });
+ actual->EachCell<int32>(
+ [&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
int64 sum = 0;
for (int32 i = 0; i < range_size; ++i) {
sum += Square(static_cast<int64>(counts[i] - expected_count));
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 346f702488..8c62adea23 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -53,7 +54,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -115,8 +115,7 @@ class ReduceTest : public ClientLibraryTestBase {
ErrorSpec(0.001));
}
- void RunR1ToR0PredTest(bool and_reduce,
- tensorflow::gtl::ArraySlice<int> input_data) {
+ void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
const int element_count = input_data.size();
XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
@@ -261,8 +260,8 @@ class ReduceTest : public ClientLibraryTestBase {
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_floating_point<NativeT>::value,
XlaBuilder>::type* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments,
ErrorSpec(0.01, 1e-4));
}
@@ -271,8 +270,8 @@ class ReduceTest : public ClientLibraryTestBase {
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_integral<NativeT>::value,
XlaBuilder>::type* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments);
}
@@ -304,7 +303,7 @@ class ReduceTest : public ClientLibraryTestBase {
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
// NativeT can be bool, and std::vector<bool> does not convert to
- // ArraySlice.
+ // Span.
std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
for (int64 colno = 0; colno < cols; ++colno) {
NativeT column_result = initial_value;
@@ -316,7 +315,7 @@ class ReduceTest : public ClientLibraryTestBase {
}
ComputeAndCompareGeneric<NativeT>(
- &builder, tensorflow::gtl::ArraySlice<NativeT>(expected.get(), cols),
+ &builder, absl::Span<const NativeT>(expected.get(), cols),
{input_global_data.get()});
}
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 60167619a4..997880a018 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -57,7 +57,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
- return ErrorSpec(1e-1, 5e-2);
+ return ErrorSpec(2e-1, 6e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
@@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
void ReduceWindowAdd(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
&builder_);
@@ -81,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
}
void ReduceWindowMax(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
auto init =
CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
@@ -92,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
}
void ReduceWindowMin(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
auto init =
CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
@@ -1303,7 +1303,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
+ LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
@@ -1327,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? +[](float a, float b) { return a + b; }
: +[](float a, float b) { return std::max(a, b); };
auto expected = ReferenceUtil::ReduceWindow1DGeneric(
- /*operand=*/tensorflow::gtl::ArraySlice<float>(input_vector),
+ /*operand=*/absl::Span<const float>(input_vector),
/*init=*/kInitValue,
/*reduce_func=*/reduce_func,
/*window=*/param.window_bounds,
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index 368f5583c9..ae24eb5eb4 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 382d1b1ae7..17d12715f6 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -689,9 +689,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 1, 1);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -711,9 +710,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 4, 1);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -734,9 +732,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(5, 10, 2, 3);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -747,7 +744,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
/*new_sizes=*/{5, 60});
Array2D<float> expected_array(5, 60);
- input.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* cell) {
+ input.Each([&](absl::Span<const int64> indices, float* cell) {
expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
*cell;
});
@@ -762,7 +759,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
std::uniform_real_distribution<float> distribution;
Array4D<float> input_array(2, 3, 5, 7);
input_array.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ [&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
@@ -842,9 +839,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
std::vector<int64> bounds = {2, 2, 2, 2};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -871,9 +867,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
std::vector<int64> bounds = {1, 1, 250, 300};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -900,9 +895,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
std::vector<int64> bounds = {5, 5, 1, 10};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -930,9 +924,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
std::vector<int64> bounds = {5, 5, 10, 1};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
@@ -959,9 +952,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
std::vector<int64> bounds = {3, 3, 1, 3};
std::vector<int64> new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index c755ff63c9..74ded82ddf 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -39,8 +39,8 @@ static std::array<bool, 1> use_bfloat16_params{false};
#endif
struct ReverseSpec {
- tensorflow::gtl::ArraySlice<int64> input_dims;
- tensorflow::gtl::ArraySlice<int64> reversal;
+ absl::Span<const int64> input_dims;
+ absl::Span<const int64> reversal;
bool use_bfloat16;
string ToTestCaseName() const {
@@ -91,17 +91,16 @@ TEST_P(FloatReverseTest, Reverses) {
std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
std::vector<int64> output_indices(spec.input_dims.size());
- expected->EachCell<float>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float) {
- for (int64 i = 0; i < indices.size(); ++i) {
- output_indices[i] = indices[i];
- }
- float value = input_literal->Get<float>(indices);
- for (int64 dim : spec.reversal) {
- output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
- }
- expected->Set<float>(output_indices, value);
- });
+ expected->EachCell<float>([&](absl::Span<const int64> indices, float) {
+ for (int64 i = 0; i < indices.size(); ++i) {
+ output_indices[i] = indices[i];
+ }
+ float value = input_literal->Get<float>(indices);
+ for (int64 dim : spec.reversal) {
+ output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
+ }
+ expected->Set<float>(output_indices, value);
+ });
ComputeAndCompareLiteral(&builder, *expected, {});
}
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index a620fe1908..e692b8c5d5 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
string data(sizeof(float) * 2, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 2);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 2);
floats[0] = 42.0;
floats[1] = 24.0;
@@ -70,8 +69,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
string data(sizeof(float) * 4, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 4);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
// With x as the minor dimension, these will become:
floats[0] = 42.0; // y=0,x=0
floats[1] = 24.0; // y=0,x=1
@@ -105,8 +103,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
string data(sizeof(float) * 4, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 4);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
// With y as the minor dimension, these will become:
floats[0] = 42.0; // y=0,x=0
floats[1] = 24.0; // y=1,x=0
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index cf2d453f43..07460a7e01 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
// A template for building and running a binary comparison test.
template <typename NativeT>
void TestCompare(NativeT lhs, NativeT rhs, bool expected,
- std::function<XlaOp(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
- op) {
+ const std::function<XlaOp(const XlaOp&, const XlaOp&,
+ absl::Span<const int64>)>& op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
@@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
template <typename NativeT>
void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
- std::function<XlaOp(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
- op) {
+ const std::function<XlaOp(const XlaOp&, const XlaOp&,
+ absl::Span<const int64>)>& op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 99eeb12e2b..1858dcea61 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -32,8 +32,7 @@ class ScatterTest : public HloTestBase {
RunTest(hlo_text, {operand, scatter_indices, updates});
}
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index e3d4f98dd7..f737b5158b 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -42,8 +42,8 @@ struct SelectAndScatterTestParam {
std::vector<int64> operand_shape;
std::vector<int64> source_shape;
Padding padding_type;
- tensorflow::gtl::ArraySlice<int64> window_dimensions;
- tensorflow::gtl::ArraySlice<int64> window_strides;
+ absl::Span<const int64> window_dimensions;
+ absl::Span<const int64> window_strides;
};
class SelectAndScatterTest
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 69585ae39a..c9a58aefb4 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -194,7 +194,7 @@ class SliceR1Test : public ClientLibraryTestBase,
protected:
template <typename NativeT>
void Run(const R1Spec& spec) {
- // This can't be an std::vector, since you can't grab an ArraySlice of a
+ // This can't be an std::vector, since you can't grab a Span of a
// vector<bool>.
absl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 776f93d9f7..c20a7c8fe4 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -183,8 +183,8 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
- TF_CHECK_OK(literal->Populate<bool>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ TF_CHECK_OK(
+ literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -203,6 +203,7 @@ enum class ConstantType { kUnknown, kZero, kOne };
// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
+ // TODO(b/77635120): Add init values, for min, max, and their arg variants.
const HloInstruction* const root = computation.root_instruction();
if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
root->operand(0)->opcode() != HloOpcode::kParameter ||
@@ -227,16 +228,16 @@ bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
- return (
- ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
+ return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
+ (opcode == HloOpcode::kReduce &&
+ op_num >= instruction->operand_count() / 2));
}
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(
- tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +294,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
- const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
+ const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
bool no_duplicates = false;
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 277d53d423..7790737c09 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -21,11 +21,11 @@ limitations under the License.
#include <random>
#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index c101cd2d20..f2b3b49015 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -507,7 +507,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
{{10011, 20022}, {30031, 40042}}});
auto prod = absl::make_unique<Literal>(sum->shape());
ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
+ [&sum](absl::Span<const int64> indexes) {
return sum->Get<complex64>(indexes) *
(indexes[indexes.size() - 1] == 0
? complex64(1, 2)
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 6a7ddd9b55..7fd42944de 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -84,7 +84,7 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
- tensorflow::gtl::ArraySlice<absl::string_view> opcodes_to_ignore = {}) {
+ absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
string separator = "[^:]*:: +";
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
@@ -171,10 +171,10 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
ServiceExecutableRunOptions run_options(
exec_run_options, /*borrow_stream=*/nullptr,
backend->eigen_intra_op_thread_pool());
+ std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
TF_ASSERT_OK_AND_ASSIGN(
auto execution_result,
- executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
- &hlo_execution_profile));
+ executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
(void)execution_result;
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 00147015a6..7289ae7df6 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -19,12 +19,12 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
@@ -46,8 +46,7 @@ namespace xla {
Status status;
tensorflow::WritableFile* f_ptr = f.get();
literal.EachCellAsString(
- [f_ptr, &status](tensorflow::gtl::ArraySlice<int64> indices,
- const string& value) {
+ [f_ptr, &status](absl::Span<const int64> indices, const string& value) {
if (!status.ok()) {
return;
}
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index f23c5b3ef1..3a086c66bb 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -24,6 +24,7 @@ tf_cc_binary(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
@@ -43,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -68,6 +70,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -95,6 +98,7 @@ cc_library(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
alwayslink = True,
)
@@ -173,6 +177,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -194,6 +199,7 @@ tf_cc_binary(
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -213,6 +219,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
index d15b71b792..c866a13de7 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -46,7 +46,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -77,7 +77,7 @@ int main(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index c446b27a04..4375e7c138 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -59,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault {
string path_;
};
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
@@ -104,7 +104,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index d86a4474b3..723569862c 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -34,7 +34,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
+void RealMain(absl::Span<char* const> args, bool compile) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
@@ -102,7 +102,7 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args, compile);
return 0;
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
index bd8b89542f..07ef5ff656 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,7 +45,7 @@ using tensorflow::Env;
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -78,7 +78,7 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index 75b63c3b84..23ce1d235b 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/casts.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
@@ -67,9 +67,8 @@ int main(int argc, char** argv) {
floats.push_back(value);
}
- tensorflow::StringPiece content( // non-absl ok
- tensorflow::bit_cast<const char*>(floats.data()),
- floats.size() * sizeof(float));
+ absl::string_view content(absl::bit_cast<const char*>(floats.data()),
+ floats.size() * sizeof(float));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
output_file, content));
return 0;
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index e826d6fa93..ba814af476 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -59,7 +60,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -253,7 +253,7 @@ StatusOr<HloSnapshot> ParseInputFile(const string& filename,
return InvalidArgument("Could not parse %s.", filename);
}
-int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
+int RealMain(absl::Span<char* const> args, const Options& opts) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
int exit_status = EXIT_SUCCESS;
@@ -344,7 +344,7 @@ int main(int argc, char** argv) {
LOG(QFATAL) << usage;
}
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
return xla::tools::RealMain(args, opts);
}
diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc
index 10e7202acf..cdf306dfd1 100644
--- a/tensorflow/compiler/xla/tools/show_signature.cc
+++ b/tensorflow/compiler/xla/tools/show_signature.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,7 +45,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -66,7 +66,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ absl::Span<char* const> args(argv, argc);
args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index 0f607a0c8a..68cab7387c 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -76,7 +76,7 @@ string Reindent(absl::string_view original,
});
}
-bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
+bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
if (rank != permutation.size()) {
return false;
}
@@ -90,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
}
std::vector<int64> InversePermutation(
- tensorflow::gtl::ArraySlice<int64> input_permutation) {
+ absl::Span<const int64> input_permutation) {
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
std::vector<int64> output_permutation(input_permutation.size(), -1);
for (size_t i = 0; i < input_permutation.size(); ++i) {
@@ -99,8 +99,8 @@ std::vector<int64> InversePermutation(
return output_permutation;
}
-std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
- tensorflow::gtl::ArraySlice<int64> p2) {
+std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
+ absl::Span<const int64> p2) {
CHECK_EQ(p1.size(), p2.size());
std::vector<int64> output;
for (size_t i = 0; i < p1.size(); ++i) {
@@ -109,7 +109,7 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
return output;
}
-bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) {
+bool IsIdentityPermutation(absl::Span<const int64> permutation) {
for (int64 i = 0; i < permutation.size(); ++i) {
if (permutation[i] != i) {
return false;
@@ -130,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) {
}
PaddingConfig MakeEdgePaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const std::pair<int64, int64>> padding) {
PaddingConfig padding_config;
for (const std::pair<int64, int64>& dim : padding) {
auto dimension = padding_config.add_dimensions();
@@ -207,14 +207,13 @@ void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
}
}
-int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
+int64 Product(absl::Span<const int64> xs) {
return std::accumulate(xs.begin(), xs.end(), static_cast<int64>(1),
std::multiplies<int64>());
}
-std::vector<std::pair<int64, int64>> CommonFactors(
- tensorflow::gtl::ArraySlice<int64> a,
- tensorflow::gtl::ArraySlice<int64> b) {
+std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
+ absl::Span<const int64> b) {
CHECK_EQ(Product(a), Product(b));
if (0 == Product(a)) {
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 62f486369f..8ce7416474 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -29,13 +29,13 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
@@ -101,65 +101,63 @@ struct ScopedLoggingTimer {
uint64 start_micros;
};
-// Given a vector<T>, returns a MutableArraySlice<char> that points at its
+// Given a vector<T>, returns a Span<char> that points at its
// internals.
//
// Warning: if the vector is updated its storage pointer may change, so use this
// with caution (ideally in limited scopes with temporary lifetimes).
template <typename T>
-tensorflow::gtl::MutableArraySlice<uint8> MutableByteSlice(std::vector<T>* v) {
- return tensorflow::gtl::MutableArraySlice<uint8>(
- reinterpret_cast<uint8*>(v->data()), v->size() * sizeof(T));
+absl::Span<uint8> MutableByteSlice(std::vector<T>* v) {
+ return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
+ v->size() * sizeof(T));
}
// Turns an immutable slice of type T into an immutable slice of bytes with the
// same byte size.
template <typename T>
-tensorflow::gtl::ArraySlice<uint8> CastToByteSlice(
- tensorflow::gtl::ArraySlice<T> slice) {
- return tensorflow::gtl::ArraySlice<uint8>(
- reinterpret_cast<const uint8*>(slice.data()), slice.size() * sizeof(T));
+absl::Span<const uint8> CastToByteSlice(absl::Span<const T> slice) {
+ return absl::Span<const uint8>(reinterpret_cast<const uint8*>(slice.data()),
+ slice.size() * sizeof(T));
}
// Casts a byte slice to a non-byte type T, checking that the original slice
// length is a multiple of sizeof(T).
template <typename T>
-tensorflow::gtl::ArraySlice<T> CastByteSlice(
- tensorflow::gtl::ArraySlice<uint8> slice) {
+absl::Span<const T> CastByteSlice(absl::Span<const uint8> slice) {
CHECK_EQ(0, slice.size() % sizeof(T));
- return tensorflow::gtl::ArraySlice<T>(
- reinterpret_cast<const T*>(slice.data()), slice.size() / sizeof(T));
+ return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
+ slice.size() / sizeof(T));
}
// Convenience function to force a vector to convert to an immutable slice.
template <typename T>
-tensorflow::gtl::ArraySlice<T> AsSlice(const std::vector<T>& v) {
- return tensorflow::gtl::ArraySlice<T>(v);
+absl::Span<const T> AsSlice(const std::vector<T>& v) {
+ return absl::Span<const T>(v);
}
-// Converts a mutable vector pointer into a MutableArraySlice of the same
+// Converts a mutable vector pointer into a Span of the same
// type.
template <typename T>
-tensorflow::gtl::MutableArraySlice<T> AsMutableSlice(std::vector<T>* v) {
- return tensorflow::gtl::MutableArraySlice<T>(v->data(), v->size());
+absl::Span<T> AsMutableSlice(std::vector<T>* v) {
+ return absl::Span<T>(v->data(), v->size());
}
// xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source.
// Wrapper function that gives an int64 array slice view of a repeated int64
// protobuf field.
-static inline tensorflow::gtl::ArraySlice<int64> AsInt64Slice(
+static inline absl::Span<const int64> AsInt64Slice(
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
- tensorflow::gtl::ArraySlice<tensorflow::protobuf_int64> slice(v);
- return tensorflow::gtl::ArraySlice<int64>(
- reinterpret_cast<const int64*>(slice.data()), slice.size());
+ absl::Span<const tensorflow::protobuf_int64> slice(v);
+ return absl::Span<const int64>(reinterpret_cast<const int64*>(slice.data()),
+ slice.size());
}
// As above, but for uint64 types.
-static inline tensorflow::gtl::ArraySlice<uint64> AsUInt64Slice(
+static inline absl::Span<const uint64> AsUInt64Slice(
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
- tensorflow::gtl::ArraySlice<tensorflow::protobuf_uint64> slice(v);
- return tensorflow::gtl::ArraySlice<uint64>(
- reinterpret_cast<const uint64*>(slice.data()), slice.size());
+ absl::Span<const tensorflow::protobuf_uint64> slice(v);
+ return absl::Span<const uint64>(reinterpret_cast<const uint64*>(slice.data()),
+ slice.size());
}
// Compares two containers for equality. Returns true iff the two containers
@@ -175,7 +173,7 @@ template <typename Container1T,
typename ElementType = typename Container1T::value_type>
bool ContainersEqual(const Container1T& c1,
std::initializer_list<ElementType> il) {
- tensorflow::gtl::ArraySlice<ElementType> c2{il};
+ absl::Span<const ElementType> c2{il};
return ContainersEqual(c1, c2);
}
@@ -193,9 +191,9 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2,
// source and destination. The source starting index is src_base, while the
// destination one is dest_base.
template <typename D, typename S>
-void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base,
- int64 dest_stride, tensorflow::gtl::ArraySlice<S> src,
- int64 src_base, int64 src_stride, int64 count) {
+void StridedCopy(absl::Span<D> dest, int64 dest_base, int64 dest_stride,
+ absl::Span<const S> src, int64 src_base, int64 src_stride,
+ int64 count) {
for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
dest[dest_base] = static_cast<D>(src[src_base]);
}
@@ -285,7 +283,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) {
string Reindent(absl::string_view original, absl::string_view indentation);
// Checks whether permutation is a permutation of the [0, rank) integer range.
-bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
+bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
@@ -293,10 +291,11 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
-template <template <typename...> class C, typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- C<T> input) {
- tensorflow::gtl::ArraySlice<T> data(input);
+template <typename Container>
+std::vector<typename Container::value_type> Permute(
+ absl::Span<const int64> permutation, const Container& input) {
+ using T = typename Container::value_type;
+ absl::Span<const T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
@@ -305,27 +304,16 @@ std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
return output;
}
-// Override of the above that works around compile failures with gcc 7.1.1.
-// For details see https://github.com/tensorflow/tensorflow/issues/10843
-// Hide this workaround from MSVC as it causes ambiguous error.
-#ifndef _MSC_VER
-template <typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- const std::vector<T>& input) {
- return Permute<std::vector, T>(permutation, input);
-}
-#endif
-
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
- tensorflow::gtl::ArraySlice<int64> input_permutation);
+ absl::Span<const int64> input_permutation);
// Composes two permutations: output[i] = p1[p2[i]].
-std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
- tensorflow::gtl::ArraySlice<int64> p2);
+std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
+ absl::Span<const int64> p2);
// Returns true iff permutation == {0, 1, 2, ...}.
-bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation);
+bool IsIdentityPermutation(absl::Span<const int64> permutation);
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
@@ -379,7 +367,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank);
// Returns a PaddingConfig object where 'padding' contains
// (low edge padding, high edge padding) pairs for each dimension.
PaddingConfig MakeEdgePaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns true if the padding configuration has at least one dimension with
// non-zero interior padding.
@@ -446,7 +434,7 @@ std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
}
-int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
+int64 Product(absl::Span<const int64> xs);
// Returns the start indices of consecutive non-overlapping subsequences of `a`
// and `b` with the same product, i.e. `(i, j)` so
@@ -459,8 +447,8 @@ int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
//
// If the given shapes have non-zero size, returns the bounds of the shortest
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
-std::vector<std::pair<int64, int64>> CommonFactors(
- tensorflow::gtl::ArraySlice<int64> a, tensorflow::gtl::ArraySlice<int64> b);
+std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
+ absl::Span<const int64> b);
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
@@ -471,11 +459,6 @@ int64 FindIndex(const C& c, Value&& value) {
return std::distance(c.begin(), it);
}
-template <typename T>
-bool ArrayContains(tensorflow::gtl::ArraySlice<T> c, const T& value) {
- return absl::c_find(c, value) != c.end();
-}
-
template <typename C, typename Value>
void InsertAt(C* c, int64 index, Value&& value) {
c->insert(c->begin() + index, std::forward<Value>(value));
@@ -487,7 +470,7 @@ void EraseAt(C* c, int64 index) {
}
template <typename T>
-std::vector<T> ArraySliceToVector(tensorflow::gtl::ArraySlice<T> slice) {
+std::vector<T> ArraySliceToVector(absl::Span<const T> slice) {
return std::vector<T>(slice.begin(), slice.end());
}
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index 288479c893..50a3c545fb 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -37,45 +37,6 @@ TEST(UtilTest, ReindentsDifferentNumberOfLeadingSpacesUniformly) {
EXPECT_EQ(want, got);
}
-// Some smoke tests for ContainersEqual. Keeping it simple since these are just
-// basic wrappers around std::equal.
-TEST(UtilTest, ContainersEqualDefault) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {1, 2, 3};
- std::vector<int> c3 = {};
- std::vector<int> c4 = {1, 2, 3, 4};
- std::vector<int> c5 = {1, 2, 3, 4, 5};
- std::vector<int> c6 = {1, 3, 4, 5};
-
- EXPECT_TRUE(ContainersEqual(c1, c4));
- EXPECT_TRUE(ContainersEqual(c4, c1));
- EXPECT_FALSE(ContainersEqual(c1, c2));
- EXPECT_FALSE(ContainersEqual(c2, c1));
- EXPECT_FALSE(ContainersEqual(c1, c3));
- EXPECT_FALSE(ContainersEqual(c3, c1));
- EXPECT_FALSE(ContainersEqual(c1, c5));
- EXPECT_FALSE(ContainersEqual(c5, c1));
- EXPECT_FALSE(ContainersEqual(c1, c6));
- EXPECT_FALSE(ContainersEqual(c6, c1));
-}
-
-TEST(UtilTest, ContainersEqualPredicate) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {10, 20, 30, 40};
-
- EXPECT_TRUE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 < i2; }));
- EXPECT_FALSE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 > i2; }));
-}
-
-TEST(UtilTest, ContainersEqualDifferentContainerTypes) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::list<int> c2 = {1, 2, 3, 4};
-
- EXPECT_TRUE(ContainersEqual(c1, c2));
-}
-
TEST(UtilTest, HumanReadableNumFlopsExample) {
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
}
@@ -117,8 +78,8 @@ TEST(UtilTest, CommonFactors) {
/*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
};
for (const auto& test_case : test_cases) {
- EXPECT_TRUE(ContainersEqual(test_case.expected,
- CommonFactors(test_case.a, test_case.b)));
+ EXPECT_TRUE(absl::c_equal(test_case.expected,
+ CommonFactors(test_case.a, test_case.b)));
}
}
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index 268dc5db01..8ea8dbab25 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -20,11 +20,12 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace window_util {
-Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
+Window MakeWindow(absl::Span<const int64> sizes) {
Window window;
for (int64 size : sizes) {
auto* dimension = window.add_dimensions();
@@ -36,7 +37,7 @@ Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
return window;
}
-PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
+PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
PaddingConfig config;
for (int64 size : sizes) {
auto* dimension = config.add_dimensions();
diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h
index ba473e2c8c..1fb9e855fc 100644
--- a/tensorflow/compiler/xla/window_util.h
+++ b/tensorflow/compiler/xla/window_util.h
@@ -16,22 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace window_util {
// Creates a window with the given sizes in the dimensions and all strides set
// to 1.
-Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes);
+Window MakeWindow(absl::Span<const int64> sizes);
// Creates a padding config with symmetrical padding in each dimension, of value
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
// side of dimension 1, and two pixels of padding symmetrically on dimension 2.
-PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes);
+PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes);
string ToString(const WindowDimension& dim);
string ToString(const Window& window);
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index aaba5aa92e..8e43f275e1 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -105,13 +105,14 @@ enum PaddingValue {
message PaddingConfig {
// Describes the padding configuration for a dimension.
message PaddingConfigDimension {
- // Padding amount on the low-end (next to the index 0).
+ // Padding amount on the low-end (next to the index 0). May be negative.
int64 edge_padding_low = 1;
- // Padding amount on the high-end (next to the highest index).
+ // Padding amount on the high-end (next to the highest index). May be
+ // negative.
int64 edge_padding_high = 2;
- // Padding amount between the elements.
+ // Padding amount between the elements. May not be negative.
int64 interior_padding = 3;
}
@@ -393,13 +394,14 @@ message WindowDimension {
// Dilation factor of the sliding window in this dimension. A dilation factor
// of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
- // implicitly placed between each kernel element. See documentation for
- // convolution.
+ // implicitly placed between each kernel element. This value may not be less
+ // than 1. See documentation for convolution.
int64 window_dilation = 5;
// Dilation factor of the base area in this dimension. A dilation factor of 1
// means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
- // placed between each base area element. See documentation for convolution.
+ // placed between each base area element. This value may not be less than 1.
+ // See documentation for convolution.
int64 base_dilation = 6;
// Window reversal means that this dimension was logically reversed before the
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
new file mode 100644
index 0000000000..efbe980278
--- /dev/null
+++ b/tensorflow/compiler/xrt/BUILD
@@ -0,0 +1,83 @@
+# Description: Operations defined for XRT
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_libs",
+)
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "xrt_proto",
+ srcs = ["xrt.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ ],
+)
+
+cc_library(
+ name = "xrt_utils",
+ srcs = [
+ "xrt_compilation_cache.cc",
+ "xrt_device.cc",
+ "xrt_state.cc",
+ ],
+ hdrs = [
+ "xrt_compilation_cache.h",
+ "xrt_device.h",
+ "xrt_state.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/jit:xla_device",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "xrt_server",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":xrt_compile_ops_op_lib",
+ ":xrt_execute_op_op_lib",
+ ":xrt_state_ops_op_lib",
+ "//tensorflow/compiler/xrt/kernels:xrt_ops",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD
new file mode 100644
index 0000000000..5c1e86b76b
--- /dev/null
+++ b/tensorflow/compiler/xrt/cc/BUILD
@@ -0,0 +1,20 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrappers_cc",
+)
+
+tf_gen_op_wrappers_cc(
+ name = "xrt_ops",
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ pkg = "//tensorflow/compiler/xrt",
+)
diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
new file mode 100644
index 0000000000..68ba17a424
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -0,0 +1,72 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+cc_library(
+ name = "xrt_state_ops",
+ hdrs = ["xrt_state_ops.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xrt_ops",
+ srcs = [
+ "xrt_compile_ops.cc",
+ "xrt_execute_op.cc",
+ "xrt_state_ops.cc",
+ ],
+ deps = [
+ ":xrt_state_ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/stream_executor:stream_executor_headers_lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
new file mode 100644
index 0000000000..5cf2bc8861
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -0,0 +1,239 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for compiling XLA computations and managing handles that refer to
+// them.
+
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+const int kDefaultCacheSize = 100;
+
+class XRTCompileOp : public OpKernel {
+ public:
+ explicit XRTCompileOp(OpKernelConstruction* ctx);
+ ~XRTCompileOp() override;
+ XRTCompileOp(const XRTCompileOp&) = delete;
+ XRTCompileOp& operator=(const XRTCompileOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ Status Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program);
+};
+
+Status CompilationCacheKey(const xrt::XLAComputation& computation,
+ string* key) {
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ *key = strings::StrCat(fingerprint);
+ return Status::OK();
+}
+
+XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+Status XRTCompileOp::Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program) {
+ const xrt::XLAComputationConfig& config = computation_proto.config();
+
+ // The default config value is 0; treat it as 1 for convenience.
+ int num_replicas = config.num_replicas() ? config.num_replicas() : 1;
+ TF_RET_CHECK(num_replicas == 1);
+ int num_cores_per_replica =
+ config.num_cores_per_replica() ? config.num_cores_per_replica() : 1;
+ TF_RET_CHECK(num_cores_per_replica == 1);
+ TF_RET_CHECK(config.per_core_program_shape_size() == 0);
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us, while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(ctx, 0, &device_ref));
+
+ xla::LocalClient* client = device_ref.client();
+
+ // There is officially no way to use XLA in a client/server architecture where
+ // client and server are built from different revisions, because the XLA team
+ // does not want to give any guarantees about the stability of the Hlo
+ // proto. For cloud TPU this is fine because server and client versions can be
+ // assumed to be synced to the same version. For general use the mechanism
+ // here (using a snapshot from XlaComputation) works as well as the "official"
+ // XLA client/server design, which serializes the same proto between client
+ // and server, so in reality is probably fine.
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
+ client->LoadSnapshot(computation_proto.hlo_snapshot()));
+
+ std::vector<const xla::Shape*> argument_layouts(
+ config.program_shape().parameters_size());
+ for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
+ argument_layouts[i] = &config.program_shape().parameters(i);
+ }
+ xla::ExecutableBuildOptions build_options;
+ build_options.set_device_ordinal(client->default_device_ordinal());
+ build_options.set_result_layout(config.program_shape().result());
+ build_options.set_device_allocator(device_ref.backend()->memory_allocator());
+
+ VLOG(1) << "Building executable";
+ auto compile_result =
+ client->Compile(computation, argument_layouts, build_options);
+ if (!compile_result.ok()) {
+ return compile_result.status();
+ }
+ *program = std::move(compile_result.ValueOrDie());
+ return Status::OK();
+}
+
+void XRTCompileOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTCompileOp::Compute";
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ const Tensor& computation_input = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
+ errors::Internal("computation input should be a string scalar"));
+
+ xrt::XLAComputation computation_proto;
+ OP_REQUIRES(
+ ctx,
+ computation_proto.ParseFromString(computation_input.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse computation input to XLAComputation"));
+
+ string key;
+ OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx,
+ rm->LookupOrCreate<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName,
+ &cache, [](XRTCompilationCache** new_cache) {
+ *new_cache = new XRTCompilationCache(kDefaultCacheSize);
+ return Status::OK();
+ }));
+ core::ScopedUnref cache_unref(cache);
+
+ int64 uid;
+ OP_REQUIRES_OK(
+ ctx, cache->CompileIfKeyAbsent(
+ key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
+ VLOG(1) << "Compiling XLA executable";
+ return Compile(ctx, computation_proto, program);
+ }));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = uid;
+ ctx->set_output(0, output);
+}
+
+XRTCompileOp::~XRTCompileOp() = default;
+
+class XRTReleaseCompilationRefOp : public OpKernel {
+ public:
+ explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
+ ~XRTReleaseCompilationRefOp() override;
+ XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
+ XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
+ delete;
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
+XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
+ OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;
+
+void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
+
+ const Tensor& key_tensor = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()),
+ errors::Internal("computation key should be a string scalar"));
+ int64 uid = key_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
+ rm->default_container(),
+ kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ OP_REQUIRES_OK(ctx, cache->Release(uid));
+
+ VLOG(2) << "Released computation handle " << uid;
+}
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
new file mode 100644
index 0000000000..257b054f16
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -0,0 +1,254 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+
+namespace {
+
+uint32 InitialRandomSeed() {
+ // Support plumbing the TF seed through to XLA is being worked on.
+ // If a user wants deterministic behavior, their best option
+ // is to start with a known checkpoint. This also handles issues when
+ // multiple random calls can be invoked in any order by TF executor.
+ // Another option is to use stateless random ops. They have much cleaner
+ // semantics.
+ // If a user really wants to set a deterministic seed for XLA-based
+ // devices, this is the place to do it.
+ std::random_device rd;
+ // Make the starting value odd.
+ return rd() | 1;
+}
+
+uint32 GetXLARandomSeed() {
+ // We initialize counter with an odd number and increment it by two
+ // everytime. This ensures that it will never be zero, even
+ // after an overflow. When seeded with zero, some XLA backends
+ // can return all zeros instead of random numbers.
+ static std::atomic<uint32> counter(InitialRandomSeed());
+ return counter.fetch_add(2);
+}
+
+// Looks up the input `key` in the compilation cache.
+Status GetComputationCacheEntry(
+ XRTCompilationCache* cache, int64 key,
+ std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
+ return Status::OK();
+}
+
+// Populates `inputs` with the input tensors to the computation.
+Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
+ bool release_inputs,
+ std::vector<XRTTupleAllocation*>* input_tuples,
+ std::vector<xla::ShapedBuffer>* input_allocations,
+ std::vector<xla::ShapedBuffer*>* input_pointers) {
+ OpInputList arg_list;
+ TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));
+
+ input_tuples->resize(arg_list.size());
+ input_pointers->resize(arg_list.size());
+ for (int i = 0; i < arg_list.size(); ++i) {
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape()));
+ int64 input_uid = arg_list[i].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
+ if (release_inputs) {
+ // We are holding a reference to the tuple, so we can safely delete it
+ // from the resource manager here.
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid));
+ VLOG(2) << "Released allocation handle " << input_uid;
+ }
+ XRTTupleAllocation* tuple = (*input_tuples)[i];
+ input_allocations->emplace_back(tuple->ToShapedBuffer());
+ }
+ for (int i = 0; i < arg_list.size(); ++i) {
+ (*input_pointers)[i] = &(*input_allocations)[i];
+ }
+ return Status::OK();
+}
+
+// XRTExecuteOp
+
+class XRTExecuteOp : public AsyncOpKernel {
+ public:
+ explicit XRTExecuteOp(OpKernelConstruction* context);
+ ~XRTExecuteOp() override;
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+ Status DoWork(OpKernelContext* context);
+};
+
+XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
+ // Schedule onto the default queue, for unbounded concurrency. See b/73520706
+ Env::Default()->SchedClosure([this, context, done]() {
+ OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
+ done();
+ });
+}
+
+Status XRTExecuteOp::DoWork(OpKernelContext* context) {
+ VLOG(1) << "XRTExecuteOp::Compute";
+ ResourceMgr* rm;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::GetResourceManager(context, &rm));
+
+ const Tensor& execution_input = context->input(0);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
+ int64 compilation_handle = execution_input.scalar<int64>()();
+
+ const Tensor& execution_config = context->input(1);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
+ xrt::XRTExecutionConfig config_proto;
+ TF_RET_CHECK(
+ config_proto.ParseFromString(execution_config.scalar<string>()()));
+
+ int core_index_in_replica = config_proto.core_index_in_replica();
+ TF_RET_CHECK(core_index_in_replica == 0);
+ bool release_inputs = config_proto.release_input_handles();
+ bool release_compilation = config_proto.release_compilation_handle();
+
+ XRTCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+ TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));
+
+ if (release_compilation) {
+ // Process-wide cache of XLA executables.
+ TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
+ VLOG(2) << "Released compilation handle " << compilation_handle;
+ }
+
+ std::vector<XRTTupleAllocation*> input_tuples;
+ // Make a cleanup method so that we can safely return in error conditions
+ // without leaking references to allocations.
+ auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() {
+ for (auto tuple : input_tuples) {
+ if (tuple != nullptr) {
+ tuple->Unref();
+ }
+ }
+ });
+ std::vector<xla::ShapedBuffer> input_allocations;
+ std::vector<xla::ShapedBuffer*> input_pointers;
+ TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs,
+ &input_tuples, &input_allocations,
+ &input_pointers));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us, while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref));
+
+ int rng_seed = config_proto.rng_seed();
+ if (rng_seed == 0) {
+ rng_seed = GetXLARandomSeed();
+ }
+
+ se::Stream* stream = context->op_device_context()
+ ? context->op_device_context()->stream()
+ : nullptr;
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(device_ref.backend()->memory_allocator());
+ run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
+ run_options.set_rng_seed(rng_seed);
+
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ xla::LocalExecutable* executable = entry->get().get_executable();
+ auto run_result = executable->Run(input_pointers, run_options);
+ if (!run_result.ok()) {
+ return run_result.status();
+ }
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ auto scoped_buffer = run_result.ConsumeValueOrDie();
+ auto shaped_buffer = scoped_buffer.release();
+ XRTTupleAllocation* output_tuple;
+ TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
+ shaped_buffer, device_ref.backend(), device_ref.device_ordinal(),
+ &output_tuple));
+
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, TensorShape({}), &output_tensor));
+ int64 key;
+ TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
+ output_tensor->scalar<int64>()() = key;
+
+ return Status::OK();
+}
+
+XRTExecuteOp::~XRTExecuteOp() = default;
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
new file mode 100644
index 0000000000..ffea592491
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
new file mode 100644
index 0000000000..478c9663a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -0,0 +1,424 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Helper functions for templated ops.
+class XRTStateHelpers {
+ public:
+ // The Status return value allows us to use the
+ // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
+ // OpKernel::Compute method.
+ static Status MakeLiteral(const xla::LiteralProto& proto,
+ std::unique_ptr<xla::Literal>* literal) {
+ TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
+ return Status::OK();
+ }
+
+ // ParseTupleNode is the recursive function used to parse a recursive
+ // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e. the
+ // tuple shape where every leaf is an existing allocation. As a side-effect it
+ // fills in input_vector by looking up allocations from handles in the
+ // input_tensor_list as they are referenced by nodes in the proto.
+ static Status ParseTupleNode(
+ const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::Shape* shape, ResourceMgr* rm) {
+ if (tuple_node.tuples_size() > 0) {
+ // This is an internal node in the proto so descend recursively.
+ xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
+ std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
+ *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
+ xla::ShapeUtil::MakeTupleShape(subshapes);
+ for (int i = 0; i < tuple_node.tuples_size(); ++i) {
+ TF_RETURN_IF_ERROR(ParseTupleNode(
+ tuple_node.tuples(i), input_tensor_list, input_vector,
+ xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
+ }
+ } else {
+ // This is a leaf node in the proto so look up the referenced input.
+ int input_index = tuple_node.input_index();
+ if (input_index < 0 || input_index >= input_vector->size()) {
+ return errors::InvalidArgument("Invalid tuple input index ",
+ input_index, ": MakeTuple has ",
+ input_vector->size(), " inputs.");
+ }
+ bool release_this_input = tuple_node.release_input_handle();
+ XRTTupleAllocation::ExpandedTupleInput& input =
+ input_vector->at(input_index);
+ if (input.allocation != nullptr &&
+ (input.release_allocation_after_use || release_this_input)) {
+ return errors::InvalidArgument(
+ "Invalid tuple tree: input index ", input_index,
+ " is repeated but release_input_handle is true.");
+ }
+ if (input.allocation == nullptr) {
+ // We haven't dereferenced this handle yet.
+ TF_RET_CHECK(
+ TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
+ int64 key = input_tensor_list[input_index].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, key, &input.allocation));
+ input.release_allocation_after_use = release_this_input;
+ }
+ }
+ return Status::OK();
+ }
+
+ // Parses a xrt::XLATupleNode proto recursively and returns the corresponding
+ // ShapeTree where each leaf is an allocation corresponding to a handle in
+ // input_tensor_list. The ordinal of one of the allocations is returned in
+ // device_ordinal. Since it's not possible to specify a xrt::XLATupleNode with
+ // no leaves, device_ordinal will always be filled in by a successful call to
+ // ParseTupleTree.
+ static Status ParseTupleTree(
+ const xrt::XLATupleNode& tuple_tree_root,
+ const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
+ int* device_ordinal, ResourceMgr* rm) {
+ // First get the shape of the 'spine' of the new tuple, where every leaf is
+ // an existing allocation. As a side-effect dereference the input handles
+ // into allocations in input_vector.
+ xla::Shape tuple_tree_shape;
+ TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
+ input_vector, &tuple_tree_shape, rm));
+ // Make the shape tree of allocations where the shape is the spine and each
+ // leaf is one of the allocations looked up in input_vector. Internal nodes
+ // have nullptr allocations.
+ *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
+ tuple_tree_shape);
+ tuple_shape_tree->ForEachMutableElement(
+ [&](const xla::ShapeIndex& index,
+ XRTTupleAllocation::ExpandedTupleInput* element) {
+ if (tuple_shape_tree->IsLeaf(index)) {
+ // Find the matching leaf in the proto tree.
+ const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
+ for (int i = 0; i < index.size(); ++i) {
+ tuple_node = &tuple_node->tuples(index[i]);
+ }
+ // Copy the appropriate input allocation to the leaf of the
+ // tuple_shape_tree.
+ int input_index = tuple_node->input_index();
+ *element = input_vector->at(input_index);
+ CHECK(element->release_allocation_after_use ==
+ tuple_node->release_input_handle());
+ // We just need to know the device_ordinal of one of the
+ // allocations. We will validate later that they are all the same.
+ *device_ordinal = (*element).allocation->device_ordinal();
+ }
+ });
+ return Status::OK();
+ }
+};
+
+// Op that allocates memory for a literal and transfers it to the device.
+template <class DeviceAccessor>
+class XRTAllocateOp : public OpKernel {
+ public:
+ explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTAllocateOp() override = default;
+ XRTAllocateOp(const XRTAllocateOp&) = delete;
+ XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTAllocateOp::Compute";
+
+ const Tensor& allocation_info = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
+ errors::Internal("allocation input should be a string scalar"));
+ xrt::XLAAllocation allocation_proto;
+ OP_REQUIRES(
+ ctx,
+ allocation_proto.ParseFromString(allocation_info.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse allocation input to XLAAllocation"));
+
+ std::unique_ptr<xla::Literal> literal;
+ OP_REQUIRES_OK(
+ ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us, while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx,
+ DeviceAccessor::InitScopedRef(
+ ctx, allocation_proto.device_ordinal(), &device_ref));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
+ *literal, device_ref.backend(),
+ device_ref.device_ordinal(), &allocation));
+
+ // Intern takes ownership of our reference to allocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
+// input.
+template <bool discard_, class DeviceAccessor>
+class XRTSubTupleOp : public OpKernel {
+ public:
+ explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTSubTupleOp() override = default;
+ XRTSubTupleOp(const XRTSubTupleOp&) = delete;
+ XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTSubTupleOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+ errors::Internal("computation input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ const Tensor& subtuple_info = ctx->input(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
+ errors::Internal("tuple index input should be an int32 vector"));
+ xla::ShapeIndex shape_index;
+ for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
+ shape_index.push_back(subtuple_info.vec<int32>()(i));
+ }
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+ XRTTupleAllocation* suballocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::MakeSubBuffer(allocation, shape_index,
+ &suballocation, !discard_));
+
+ // Intern takes ownership of our reference to suballocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, suballocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that allocates memory for a literal and transfers it to the device.
+template <class DeviceAccessor>
+class XRTMakeTupleOp : public OpKernel {
+ public:
+ explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTMakeTupleOp() override = default;
+ XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
+ XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTMakeTupleOp::Compute";
+
+ const Tensor& tuple_info = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
+ errors::Internal("tuple description input should be a string scalar"));
+ xrt::XLATupleNode tuple_proto;
+ OP_REQUIRES(
+ ctx, tuple_proto.ParseFromString(tuple_info.scalar<string>()()),
+ errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));
+
+ OpInputList arg_list;
+ OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));
+
+ // For each input, the allocation it corresponds to and a flag indicating
+ // whether or not it should be released, i.e. discarded from the resource
+ // manager. One ref on each allocation is owned by this vector, and freed on
+ // exit.
+ std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
+ arg_list.size());
+ auto cleanup = gtl::MakeCleanup([&input_vector] {
+ for (auto& input : input_vector) {
+ if (input.allocation != nullptr) {
+ input.allocation->Unref();
+ }
+ }
+ });
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
+ // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
+ // the allocations. It is guaranteed that there is at least on allocation in
+ // any legal tree. We validate below in XRTTupleAllocation::MakeTuple that
+ // all the allocations are on the same device.
+ int device_ordinal;
+ OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
+ tuple_proto, arg_list, &input_vector,
+ &tuple_shape_tree, &device_ordinal, rm));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us, while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(
+ ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));
+
+ XRTTupleAllocation* output_allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
+ device_ref.backend(), device_ref.device_ordinal(),
+ tuple_shape_tree, &output_allocation));
+ // Add a ScopedUnref to simplify the error path while calling
+ // DeleteFromResourceManager.
+ core::ScopedUnref unref(output_allocation);
+ for (int i = 0; i < input_vector.size(); ++i) {
+ if (input_vector[i].release_allocation_after_use) {
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, arg_list[i].scalar<int64>()()));
+ }
+ }
+
+ // Intern takes ownership of a reference to output_allocation, so add
+ // another since the ScopedUnref will release one when this method exits.
+ output_allocation->Ref();
+ int64 key;
+ OP_REQUIRES_OK(ctx, output_allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that reads a device-resident tuple to host memory and returns it as a
+// literal.
+template <bool discard_, class DeviceAccessor>
+class XRTReadLiteralOp : public OpKernel {
+ public:
+ explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReadLiteralOp() override = default;
+ XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
+ XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReadLiteralOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+ errors::Internal("computation input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us, while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+ ctx, allocation->device_ordinal(), &device_ref));
+
+ std::unique_ptr<xla::Literal> literal;
+ OP_REQUIRES_OK(
+ ctx, allocation->ToLiteral(device_ref.backend(),
+ device_ref.device_ordinal(), &literal));
+ xla::LiteralProto literal_proto = literal->ToProto();
+
+ Tensor output(DT_STRING, TensorShape({}));
+ literal_proto.SerializeToString(&output.scalar<string>()());
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that discards a handle to device memory.
+template <class DeviceAccessor>
+class XRTReleaseAllocationOp : public OpKernel {
+ public:
+ explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReleaseAllocationOp() override = default;
+ XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
+ XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReleaseAllocationOp::Compute";
+
+ const Tensor& allocation_handle = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()),
+ errors::Internal("handle input should be an int64 scalar"));
+ int64 key = allocation_handle.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key));
+
+ VLOG(2) << "Released allocation handle " << key;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
new file mode 100644
index 0000000000..5cfc8711f9
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTCompile")
+ .Input("computation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a computation proto, compiles it, and places it in the global compilation
+cache.
+
+'computation' is a serialized xrt::XLAComputation proto.
+'handle' is an identifier that can be used in other ops to refer to the
+computation.
+)");
+
+REGISTER_OP("XRTReleaseCompilationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards a computation from the compilation cache. The handle cannot be
+subsequently used.
+
+'handle' is an id returned from a XRTCompile Op.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
new file mode 100644
index 0000000000..fda4c31298
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTExecute")
+ .Attr("Ninputs: int")
+ .Input("computation_handle: int64")
+ .Input("execution_config: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Runs a previously-compiled computation on a core. If
+execution_config.release_input_handles is true, the input handles are invalid
+after this op runs.
+
+'computation_handle' is an id returned by XRTCompile.
+'execution_config' is a serialized xrt::TPUExecutionConfig proto.
+'input_handles' is a list of ids of allocations, one per input to the compiled
+computation.
+'output_handle' is an identifier for the result of the compiled computation.
+'Ninputs' is the number of input handles.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
new file mode 100644
index 0000000000..07d025ce34
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -0,0 +1,122 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTAllocate")
+ .Input("allocation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a literal proto and transfers it to TPU device memory.
+
+'allocation' is a serialized xrt::TPUAllocation proto.
+'handle' is an id that can be used in other ops to refer to the allocation.
+)");
+
+REGISTER_OP("XRTSubTuple")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used in other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTSubTupleAndRelease")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple, and releases the handle
+of the input tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used by other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTMakeTuple")
+ .Attr("Ninputs: int")
+ .Input("tuple_description: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a new allocation constructed by assembling existing
+allocations in a tuple.
+
+'tuple_description' is a serialized xrt::XLATupleNode proto describing the
+shape of the output tuple, and whether each input handle should be aliased or
+released.
+'input_handles' is a list of input handles to assemble into the output tuple.
+'output_handle' is an id that can be used by other ops to refer to the new
+tuple.
+'Ninputs' is the number of input handles.
+)");
+
+REGISTER_OP("XRTReadLiteral")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory and returns it as a literal.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReadLiteralAndRelease")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory, and returns it as a literal, and
+releases the handle.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReleaseAllocationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards an allocation from device memory. The handle cannot be subsequently
+used.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
new file mode 100644
index 0000000000..09ab4ed95f
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -0,0 +1,65 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler:__subpackages__",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+
+cc_library(
+ name = "raw_api_test_lib",
+ testonly = 1,
+ srcs = [
+ "raw_api_test.cc",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_server",
+ "//tensorflow/compiler/xrt/cc:xrt_ops",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "raw_api_test_cpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_CPU"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_cpu_device",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "raw_api_test_gpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_GPU"],
+ tags = ["requires-gpu-sm35"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_gpu_device",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
new file mode 100644
index 0000000000..5b8516bf1d
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -0,0 +1,421 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+string* xla_test_device_ptr; // initial value set in main()
+
+string DeviceFromFlag() {
+ string xla_test_device = *xla_test_device_ptr;
+ return absl::StrCat("/device:", xla_test_device, ":0");
+}
+
+xla::LiteralProto TwoElementTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ return tuple->ToProto();
+}
+
+xla::LiteralProto ScalarLiteral() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ return scalar->ToProto();
+}
+
+xla::LiteralProto NestedTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ return nested->ToProto();
+}
+
+xla::LiteralProto MakeTuple0() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()});
+ auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()});
+ return nested1->ToProto();
+}
+
+xla::LiteralProto FloatVector(gtl::ArraySlice<float> v) {
+ auto array = xla::LiteralUtil::CreateR1<float>(v);
+ return array->ToProto();
+}
+
+bool CompareLiteralProtos(const xla::LiteralProto& a,
+ const xla::LiteralProto& b) {
+ auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = *l_a == *l_b;
+ if (!equal) {
+ LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+ << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+bool CompareLiteralToLiteralProto(const xla::Literal& a,
+ const xla::LiteralProto& b) {
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = a == *l_b;
+ if (!equal) {
+ LOG(INFO) << "Literal and LiteralProto don't match "
+ << a.ToProto().DebugString() << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+xla::XlaComputation AddAndScale() {
+ xla::XlaBuilder builder("AddAndScale");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ auto c = xla::ConstantR0<float>(&builder, 3.0f);
+ xla::Mul(sum, c);
+ return builder.Build().ValueOrDie();
+}
+
+xla::XlaComputation AddAndTuple() {
+ xla::XlaBuilder builder("AddAndTuple");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ xla::Tuple(&builder, {sum});
+ return builder.Build().ValueOrDie();
+}
+
+void StoreComputationSnapshot(const xla::XlaComputation& computation,
+ xla::HloSnapshot* dst) {
+ auto snapshot = computation.Snapshot().ValueOrDie();
+ *dst = *snapshot;
+}
+
+TEST(RawApiTest, ReadAndWriteState) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteral(root, handle);
+ auto release = ops::XRTReleaseAllocationHandle(
+ root.WithControlDependencies(read_back), handle);
+ TF_ASSERT_OK(root.status());
+
+ tensorflow::ClientSession session(root);
+ std::vector<tensorflow::Tensor> outputs;
+ TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
+ {release}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, ReadAndWriteStateAutoFree) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, SubBuffer) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = NestedTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto base_handle = ops::XRTAllocate(root, value);
+ auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
+ auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
+ auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
+ auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
+ auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
+ auto sub_00 = ops::XRTSubTupleAndRelease(
+ root.WithControlDependencies(
+ {sub_0.output_handle.op(), sub_1.output_handle.op()}),
+ base_handle, index_00);
+ auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
+ auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
+ auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
+
+ auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
+ auto base_elements = base_literal->DecomposeTuple();
+ auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
+ xla::LiteralProto response_00;
+ EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
+}
+
+TEST(RawApiTest, MakeTuple) {
+ xrt::XLAAllocation alloc_0;
+ alloc_0.set_device_ordinal(0);
+ *alloc_0.mutable_value() = TwoElementTuple();
+ xrt::XLAAllocation alloc_1;
+ alloc_1.set_device_ordinal(0);
+ *alloc_1.mutable_value() = ScalarLiteral();
+
+ // The trivial tuple that just forwards its input and releases it.
+ xrt::XLATupleNode desc_0;
+ desc_0.set_input_index(0);
+ desc_0.set_release_input_handle(true);
+
+ xrt::XLATupleNode desc_1;
+ auto subdesc_10 = desc_1.add_tuples();
+ auto subdesc_11 = desc_1.add_tuples();
+ subdesc_10->set_input_index(0);
+ auto subdesc_110 = subdesc_11->add_tuples();
+ subdesc_110->set_input_index(0);
+ auto subdesc_111 = subdesc_11->add_tuples();
+ subdesc_111->set_input_index(1);
+
+ xrt::XLATupleNode desc_2;
+ auto subdesc_20 = desc_2.add_tuples();
+ auto subdesc_21 = desc_2.add_tuples();
+ subdesc_20->set_input_index(1);
+ subdesc_20->set_release_input_handle(true);
+ subdesc_21->set_input_index(0);
+ subdesc_21->set_release_input_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
+ auto handle_0 = ops::XRTAllocate(root, value_0);
+ auto value_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
+ auto handle_1 = ops::XRTAllocate(root, value_1);
+ auto tuple_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
+ auto handle_2 =
+ ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
+ // handle_0 has now been released.
+ auto tuple_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
+ auto handle_3 = ops::XRTMakeTuple(
+ root, tuple_1,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ auto tuple_2 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
+ // Make sure this runs after handle_3 has completed, since it will free
+ // handle_1 and handle_2.
+ auto handle_4 = ops::XRTMakeTuple(
+ root.WithControlDependencies(handle_3), tuple_2,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ // handle_1 and handle_2 have now been released.
+
+ auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
+ auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+
+ auto expected_0 = MakeTuple0();
+ EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
+ auto expected_1 = NestedTuple();
+ EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
+}
+
+TEST(RawApiTest, CompileAndExecute) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+}
+
+TEST(RawApiTest, CompileAndExecuteReturnTuple) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::F32, {2})});
+ StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
+ auto expected = xla::LiteralUtil::MakeTuple({sum.get()});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+}
+
+} // namespace
+
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
+ "Tensorflow device type to use for test, e.g., XLA_CPU"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto
new file mode 100644
index 0000000000..5678f0905f
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt.proto
@@ -0,0 +1,78 @@
+syntax = "proto3";
+
+package xrt;
+
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
+
+// Options for an XLA compilation.
+message XLAComputationConfig {
+ // The number of replicas the computation will be run on. If this is
+ // default (0) it is interpreted as 1.
+ int32 num_replicas = 1;
+ // The number of "model-parallel" cores per replica. If this is
+ // default (0) it is interpreted as 1.
+ int32 num_cores_per_replica = 2;
+ // Optional metadata about host sends and recvs.
+ tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
+
+ // The arg/result shapes for the whole computation.
+ xla.ProgramShape program_shape = 4;
+ // The arg/result shapes for each core of a model-parallel
+ // computation. per_core_args_and_result_shapes is optional for a
+ // single-core computation.
+ repeated xla.ProgramShape per_core_program_shape = 5;
+}
+
+// Options and XLA computation for a compilation.
+message XLAComputation {
+ XLAComputationConfig config = 1;
+ xla.HloSnapshot hlo_snapshot = 2;
+}
+
+// Literal to allocate space for, and transfer to, device memory.
+message XLAAllocation {
+ int32 device_ordinal = 1;
+ xla.LiteralProto value = 2;
+}
+
+// Node in a tree describing a tuple constructed from input handles. A
+// node is an internal node if tuples is non-empty, in which case
+// input_index and release_input_handle are ignored. Otherwise a node
+// is a leaf node. Each leaf XLATupleNode is the index of an input
+// which corresponds to a handle that will be grafted onto the output
+// tuple at that location. If release_input_handle is true that input
+// handle will be released and become invalid. Inputs may be repeated
+// in which case leaves of the output tuple will alias. If an input is
+// repeated, release_input_handle must be false for every leaf where
+// that input appears.
+//
+// For example, if input 0 has shape {} and input 1 has shape {2,3}
+// then the XLATupleNode with structure {1,{0,1}} corresponds to a
+// tuple with shape {{2,3},{{},{2,3}}}.
+message XLATupleNode {
+ int32 input_index = 1;
+ bool release_input_handle = 2;
+ repeated XLATupleNode tuples = 3;
+}
+
+// Options for an XLA execution.
+message XRTExecutionConfig {
+ // Local device to run on. This is present because the execute Op
+ // may be placed on a device such as CPU or TPU_SYSTEM that
+ // logically manages multiple cores.
+ int32 device_ordinal = 1;
+ // Which model-parallel computation to run from the compiled bundle.
+ int32 core_index_in_replica = 2;
+ // Optional key to disambiguate between executions. This is only
+ // needed if multiple host send/recvs may be outstanding
+ // concurrently with executions.
+ string execution_instance_key = 3;
+ // If non-zero, rng_seed to reset the core with.
+ uint32 rng_seed = 4;
+ // If true, release allocation handles on the inputs after running.
+ bool release_input_handles = 5;
+ // If true, release the handle to the computation after running.
+ bool release_compilation_handle = 6;
+}
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
new file mode 100644
index 0000000000..4844c7fb71
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
+
+XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
+ CompiledSubgraph* entry)
+ : parent_(parent), entry_(entry) {
+ entry_->Ref();
+}
+
+XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
+ parent_->DiscardEntryRef(entry_);
+}
+
+XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
+ return XRTCompilationCacheEntry(entry_->program.get());
+}
+
+XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
+ : max_cache_entries_(max_number_of_entries) {
+ CHECK_GE(max_cache_entries_, 0);
+ VLOG(1) << "Created compilation cache max " << max_cache_entries_
+ << " entries.";
+}
+
+XRTCompilationCache::~XRTCompilationCache() {
+ VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
+ while (!entries_by_last_use_.empty()) {
+ MarkOldestEntryForEviction();
+ }
+ // By the time the cache is deleted all reference holders should have already
+ // been deleted, since they were holding references to the cache. So all
+ // entries should be gone at this point.
+ CHECK_EQ(cache_.size(), 0);
+ CHECK_EQ(entries_by_uid_.size(), 0);
+ CHECK_EQ(cache_entries_, 0);
+ CHECK_EQ(marked_for_eviction_entries_, 0);
+}
+
+Status XRTCompilationCache::Release(int64 uid) {
+ absl::MutexLock lock(&mu_);
+ auto iter = entries_by_uid_.find(uid);
+
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No cache entry found for uid ", uid);
+ }
+
+ DiscardEntryRefLocked(iter->second);
+
+ VLOG(1) << "After releasing entry " << uid << " refs cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return Status::OK();
+}
+
+void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
+ absl::MutexLock lock(&mu_);
+ DiscardEntryRefLocked(entry);
+}
+
+void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
+ if (entry->RefCountIsOne()) {
+ // The last reference to this entry is going away, so really delete it from
+ // the cache in such a way that it can't be restored by being looked up
+ // again.
+
+ // Sanity-check that it has been marked for eviction.
+ CHECK(entries_by_last_use_.find(entry->last_use) ==
+ entries_by_last_use_.end());
+ // Update the counter tracking how much space is taken up by entries that
+ // are marked for eviction.
+ --marked_for_eviction_entries_;
+
+ // Remove the entry from the cache.
+ auto erased = cache_.erase(entry->key);
+ if (erased == 0) {
+ LOG(FATAL) << "Tried to discard nonexistent cache entry";
+ }
+ erased = entries_by_uid_.erase(entry->uid);
+ CHECK_EQ(erased, 1);
+ }
+ entry->Unref();
+}
+
+void XRTCompilationCache::MarkOldestEntryForEviction() {
+ CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
+ VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
+ entries_by_last_use_.erase(entry_to_mark->last_use);
+ --cache_entries_;
+ ++marked_for_eviction_entries_;
+ // Discard the cache's reference to entry. If steps are holding onto
+ // references to entry it won't be deleted until the last step holding it
+ // completes. It stays in the cache in the meantime and can be resurrected
+ // by a call to CompileIfKeyAbsent if that occurs before the last reference
+ // expires.
+ DiscardEntryRefLocked(entry_to_mark);
+}
+
+void XRTCompilationCache::LookupEntryMarkedForEviction(
+ CompiledSubgraph* entry) {
+ // The entry was previously marked for eviction (or is newly created) so
+ // unmark it. Add a reference (owned by the cache), update the cache size, and
+ // mark something old for eviction if necessary.
+ entry->Ref();
+ --marked_for_eviction_entries_;
+ ++cache_entries_;
+
+ // Mark the least-recently-used non-marked entry for eviction. Never mark the
+ // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
+ // which means there's only one entry not already marked for eviction), so
+ // that an entry persists in the cache even if it is larger than the allocated
+ // cache size.
+ while (entries_by_last_use_.size() > 1 &&
+ cache_entries_ > max_cache_entries_) {
+ MarkOldestEntryForEviction();
+ }
+}
+
+XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) {
+ CompiledSubgraph* entry = new CompiledSubgraph();
+ entry->parent = this;
+ entry->key = key;
+ entry->uid = next_uid_++;
+ // Add the entry to the cache. Once the computation has been compiled,
+ // UpdateEntryAfterCompilation will be called to potentially mark old entries
+ // that don't fit any more for eviction.
+ //
+ // At this point there is one reference to entry, which is owned by the caller
+ // who created the entry. A second reference, owned by the cache, will be
+ // added below since we leave the entry in the 'marked for eviction' state
+ // here.
+ auto cache_inserted =
+ cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
+ CHECK(cache_inserted.second);
+
+ // Initialize the program outside the lock so that other cache operations
+ // can proceed during the (potentially lengthy) initialization.
+ Status s;
+ std::unique_ptr<xla::LocalExecutable> program;
+ {
+ mu_.Unlock();
+ { s = initialize_program(&program); }
+ mu_.Lock();
+ }
+
+ // Add the entry to the uid index.
+ auto uid_inserted = entries_by_uid_.insert(
+ std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
+ CHECK(uid_inserted.second);
+
+ entry->initialized = true;
+ entry->initialization_status = s;
+ if (s.ok()) {
+ entry->program = std::move(program);
+ }
+ // Add the entry to marked_for_eviction_entries_ since it will be adjusted
+ // down again when the newly-created entry gets unmarked.
+ ++marked_for_eviction_entries_;
+ return entry;
+}
+
+Status XRTCompilationCache::CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function) {
+ CompiledSubgraph* entry = nullptr;
+
+ absl::MutexLock lock(&mu_);
+ auto iter = cache_.find(key);
+
+ if (iter == cache_.end()) {
+ // The single ref on the newly-created entry is owned by the caller.
+ VLOG(1) << "Before adding new entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = InitializeEntry(key, compile_function);
+ } else {
+ VLOG(1) << "Before refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = iter->second;
+ // Make a new reference that is owned by the caller.
+ entry->Ref();
+ // Block if necessary until the subgraph has been initialized.
+ mu_.Await(absl::Condition(
+ +[](CompiledSubgraph* e) { return e->initialized; }, entry));
+ }
+
+ // Let the caller know the uid of the entry.
+ *uid = entry->uid;
+
+ // Remove the old LRU-table entry if it wasn't already marked for eviction.
+ auto erased = entries_by_last_use_.erase(entry->last_use);
+ // Update the LRU table indicating this entry is the most recently used.
+ entry->last_use = use_counter_++;
+ entries_by_last_use_[entry->last_use] = entry;
+ if (erased == 0) {
+ // The entry had been marked for eviction, or is newly created.
+ LookupEntryMarkedForEviction(entry);
+ }
+
+ VLOG(1) << "After refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return entry->initialization_status;
+}
+
+Status XRTCompilationCache::Lookup(
+ int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ entry->reset();
+
+ absl::MutexLock lock(&mu_);
+ const auto iter = entries_by_uid_.find(uid);
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No executable found for uid ", uid);
+ }
+ CompiledSubgraph* cache_entry = iter->second;
+ *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
+ new EntryRefImpl(this, cache_entry));
+ return Status::OK();
+}
+
+string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; }
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h
new file mode 100644
index 0000000000..c505299a45
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h
@@ -0,0 +1,238 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace tensorflow {
+
+extern const char* kXRTCompilationCacheResourceName;
+
+struct XRTCompilationCacheEntry {
+ explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable)
+ : executable(executable) {}
+
+ // Returns a non-owned pointer to an immutable executable.
+ xla::LocalExecutable* get_executable() const { return executable; }
+
+ private:
+ xla::LocalExecutable* executable;
+};
+
+// Base class for a reference to a cached executable. A unique_ptr to a
+// XRTCompilationCacheEntryRef is returned by the cache Lookup methods below,
+// and ensures the underlying executable is not garbage-collected until the
+// client discards the ptr.
+class XRTCompilationCacheEntryRef {
+ public:
+ virtual ~XRTCompilationCacheEntryRef() = default;
+
+ // Returns a XRTCompilationCacheEntry that should not be used beyond the
+ // lifetime of the XRTCompilationCacheEntryRef.
+ virtual XRTCompilationCacheEntry get() = 0;
+};
+
+// Cache for compiled XLA executables.
+// TODO(b/112646171) rationalize this with the other compilation caches.
+//
+// Each key identifies a unique XLA computation, and the value is executable
+// generated by compiling the computation.
+//
+// When a computation is considered for compilation, the client calls
+//
+// auto key = <compute key for computation>;
+// auto compile_function = <lambda to compile computation into executable>;
+// int64 uid;
+// CompileIfKeyAbsent(computation_key, &uid, compile_function);
+//
+// where computation_key is the key computed for the computation. On success,
+// uid contains an identifier that can be used to look up the executable. If the
+// compiled executable were not present in the cache, compile_function would be
+// called to generate it.
+//
+// The caller is responsible for calling Release(uid) once for every
+// call to CompileIfKeyAbsent(key, ...) to discard the reference to the
+// compilation results, after the caller is sure it will not look up the
+// compiled executables again.
+//
+// Subsequently the client can call
+//
+// std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+// Lookup(uid, &entry);
+// auto proto = entry->get();
+//
+// to access a cached executable.
+class XRTCompilationCache : public ResourceBase {
+ public:
+ // There is no way in general to discover the size taken by an XLA executable,
+ // so the cache defaults to a specific number of entries to determine when to
+ // start evicting programs. TODO(b/112592410) change this if the XLA API gets
+ // a mechanism to query size.
+ explicit XRTCompilationCache(int max_number_of_entries);
+ ~XRTCompilationCache() override;
+
+ // Ensures there is an entry for key present in the cache. By the time
+ // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
+ // for key, and that entry will remain valid at least until Release is called
+ // on the returned uid. The first call to CompileIfKeyAbsent with a key that
+ // is not in the cache will evaluate compile_function to compute the value to
+ // use in the entry. Subsequent calls with the same key will block until
+ // compile_function completes. Other cache reads and inserts may proceed on
+ // other threads while compile_function is executing. The caller is
+ // responsible for calling Release(uid) to manually discard its reference to
+ // the compiled program, once the caller will not look up the compiled program
+ // again.
+ //
+ // compile_function should compile the computation represented by key and fill
+ // the xla::LocalExecutable into its passed argument. It should return OK
+ // if and only if compilation succeeds. The executable will be discarded on
+ // non-OK status.
+ Status CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function);
+
+ Status Release(int64 uid);
+
+ // Looks up an executable corresponding to uid. On success a pointer to an
+ // EntryRef holding the program is returned in entry.
+ Status Lookup(int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry);
+
+ string DebugString() override;
+
+ private:
+ // An entry in the compilation cache. The entry is deleted once it has been
+ // marked for eviction from the cache _and_ all looked-up entries have been
+ // released. When the entry is first created, it is uninitialized and a
+ // client-supplied compilation function is run outside the cache's lock to
+ // generate the program to be stored in the entry. Any other client that
+ // requests the entry will block until it has been initialized. Each entry has
+ // a last_use value that set from a monotonically-increasing counter in the
+ // cache whenever the entry is referenced. When the cache becomes full,
+ // entries are marked for eviction in LRU order.
+ struct CompiledSubgraph : public core::RefCounted {
+ ~CompiledSubgraph() override = default;
+
+ XRTCompilationCache* parent = nullptr; // Not owned.
+ bool initialized = false;
+ // The Status returned by the compilation function when the entry is
+ // initialized. This status will be returned to any client that requests the
+ // entry.
+ Status initialization_status;
+ // Counter to keep track of LRU entries for the eviction policy.
+ int64 last_use = -1;
+ // The unique key describing this entry.
+ string key;
+ // The uid describing this entry.
+ int64 uid;
+ // The compiled payload corresponding to the key.
+ std::unique_ptr<xla::LocalExecutable> program;
+ };
+
+ // Wrapper for a cache entry that holds a reference to the entry until the
+ // wrapper is deleted. This wrapper is the concrete type of
+ // XRTCompilationCacheEntryRef returned by Lookup.
+ class EntryRefImpl : public XRTCompilationCacheEntryRef {
+ public:
+ EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry);
+ ~EntryRefImpl() override;
+
+ XRTCompilationCacheEntry get() override;
+
+ private:
+ XRTCompilationCache* parent_; // Not owned.
+ // A reference to entry_ is acquired in the contructor and released via
+ // parent->DiscardEntryRef in the destructor.
+ CompiledSubgraph* entry_;
+ };
+
+ // Releases one reference to entry. This is called by the cache when entry is
+ // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the
+ // last reference to entry is released, entry is removed from cache_.
+ void DiscardEntryRef(CompiledSubgraph* entry);
+ void DiscardEntryRefLocked(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Marks the oldest unmarked entry for eviction. Requires that there is at
+ // least one such entry.
+ void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Updates datastructures to indicate that entry, which had been marked for
+ // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+ // entry is newly created, or an entry that has been marked for eviction but
+ // not yet evicted is looked up.
+ //
+ // First the entry is unmarked for eviction, i.e. the cache gains a reference
+ // to entry, entry's last_use field is set to be the most recent value of
+ // use_counter_ and entries_by_last_use_ is updated accordingly.
+ //
+ // Next, the size of the cache is examined to see if any other entries need to
+ // be marked for eviction now that entry has been unmarked. While the total
+ // number of unmarked cached entries is greater than max_cache_entries_,
+ // entries are marked for eviction in LRU order. The most recently used entry
+ // is never marked for eviction, so an entry larger than the max cache entries
+ // will remain in the cache until it is replaced by something else.
+ void LookupEntryMarkedForEviction(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Creates a new entry by running initialize_program and places it in the
+ // cache to be looked up by key. The new entry is in the 'marked for eviction'
+ // state (not present in entries_by_last_use_) and the caller is expected to
+ // call LookupEntryMarkedForEviction after InitializeEntry.
+ //
+ // **InitializeEntry releases mu_ during the call to initialize_program.**
+ CompiledSubgraph* InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // The maximum number of entries that are stored in the cache before entries
+ // are marked for eviction.
+ const int max_cache_entries_;
+
+ mutable absl::Mutex mu_;
+ // The uid to assign to the next new entry created.
+ int64 next_uid_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are stored and not marked for eviction.
+ int cache_entries_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are marked for eviction.
+ int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0;
+ // The value to assign to the last_use field of the next entry that is looked
+ // up.
+ int64 use_counter_ GUARDED_BY(mu_) = 0;
+ // All the executables that can be looked up in the cache index by key. An
+ // entry is marked for eviction iff it is present in cache_ and not in
+ // entries_by_last_use_.
+ std::unordered_map<string, CompiledSubgraph*> cache_ GUARDED_BY(mu_);
+ // All the executable entries that can be looked up in the cache indexed by
+ // uid.
+ std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ GUARDED_BY(mu_);
+ // Map from last_use to entry, used to mark entries for eviction in LRU
+ // order. If an entry's last_use counter is not present as a key in
+ // entries_by_last_use_ then the entry has been marked for eviction.
+ std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc
new file mode 100644
index 0000000000..ea40e6c895
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for managing access to XLA resources.
+
+#include "tensorflow/compiler/xrt/xrt_device.h"
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
+ OpKernelContext* ctx, ResourceMgr** rm) {
+ *rm = ctx->resource_manager();
+ return Status::OK();
+}
+
+/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
+ OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
+ const XlaDevice::Metadata* metadata;
+ TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+ if (device_ordinal != metadata->device_ordinal()) {
+ return errors::Internal("XRT device ordinal requested ", device_ordinal,
+ " on device with ordinal ",
+ metadata->device_ordinal());
+ }
+ scoped_ref->Acquire(metadata->client());
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h
new file mode 100644
index 0000000000..1e3fddd2a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.h
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+// This accessor is used for XLA CPU/GPU. It uses the device resource manager,
+// so e.g., on multi-GPU setups the compilation cache will not be shared across
+// devices.
+class XRTGenericDeviceAccessor {
+ public:
+ static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
+
+ // We use a ScopedRef pattern here even though it's not strictly necessary,
+ // just so that templated uses of this and the TPU accessor class will be as
+ // similar as possible.
+ class ScopedRef {
+ public:
+ ScopedRef() {}
+ ~ScopedRef() {}
+
+ ScopedRef(const ScopedRef&) = delete;
+ ScopedRef& operator=(const ScopedRef&) = delete;
+
+ // Returns the XLA device protected by this ScopedRef.
+ xla::LocalClient* client() { return client_; }
+ xla::Backend* backend() { return client_->mutable_backend(); }
+ int device_ordinal() { return 0; }
+
+ private:
+ // XRTGenericDeviceAccessor::InitScopedRef is the only way to initialize
+ // ScopedRef.
+ friend class XRTGenericDeviceAccessor;
+
+ void Acquire(xla::LocalClient* client) { client_ = client; }
+
+ xla::LocalClient* client_ = nullptr;
+ };
+
+ static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
+ ScopedRef* scoped_ref);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
new file mode 100644
index 0000000000..911ac9a78b
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include "tensorflow/compiler/xrt/xrt_state.h"
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+namespace {
+
+const char* kTupleContainer = "tuples";
+
+// Counter used to assign unique handles.
+mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
+int64 _uid GUARDED_BY(_uid_mutex) = 0;
+int64 get_uid() {
+ mutex_lock l(_uid_mutex);
+ return _uid++;
+}
+
+Status AllocateScopedShapedBuffer(
+ xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
+ std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ // XLA may use a different representation on device than the representation on
+ // the host. XLA does not document any contract for the relationship between
+ // these representations :/ Right now, the device shape is always a superset
+ // of the host shape, meaning that for any valid ShapeIndex in the host shape
+ // that ShapeIndex is also valid in the device shape, but not vice versa. In
+ // particular, some host-side types are rewritten to be tuples. We rely on
+ // this property when making sub-buffers, because we assume that if the client
+ // requests the host-shape sub-buffer at index i, that will correspond to the
+ // right device-shape sub-buffer at the same index.
+ xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
+
+ // The ScopedShapedBuffer frees the buffers that have so far been allocated if
+ // it goes out of scope. That's useful if we return early as the result of an
+ // error allocating one of the later buffers.
+ *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
+ shape, on_device_shape, allocator, device_ordinal);
+ for (auto& index_to_buffer : (*buffer)->buffers()) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(
+ xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
+ // Move our buffer into shaped_buffer, which takes ownership of it.
+ index_to_buffer.second = buffer.Forget();
+ VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
+ << " index " << index_to_buffer.first.ToString();
+ }
+
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));
+
+ return Status::OK();
+}
+
+} // namespace
+
+XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator)
+ : allocation_(allocation),
+ device_ordinal_(device_ordinal),
+ allocator_(allocator) {}
+
+XRTBufferAllocation::~XRTBufferAllocation() {
+ // Deallocate explicitly allows allocation_ to be null.
+ Status s = allocator_->Deallocate(device_ordinal_, allocation_);
+ // Nothing to do but check fail here if memory datastructures are corrupted.
+ CHECK(s.ok());
+ VLOG(2) << "Freed buffer at " << allocation_.opaque();
+}
+
+const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
+ return allocation_;
+}
+
+void XRTBufferAllocation::DiscardAllocation() {
+ // Replace the allocation with a null.
+ allocation_ = se::DeviceMemoryBase();
+}
+
+XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape)
+ : device_ordinal_(device_ordinal),
+ allocator_(allocator),
+ on_host_shape_(on_host_shape),
+ on_device_shape_(on_device_shape),
+ buffers_(&on_device_shape_) {}
+
+XRTTupleAllocation::~XRTTupleAllocation() {
+ for (auto& buffer : buffers_) {
+ buffer.second->Unref();
+ }
+}
+
+/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
+ const xla::Literal& literal, xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+
+ std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
+ TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
+ backend, device_ordinal, literal.shape(), &scoped_buffer));
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
+ stream.get(), literal, *scoped_buffer));
+
+ // By releasing the ScopedShapedBuffer we ensure that the underlying storage
+ // won't be freed when the buffer goes out of scope at the end of this
+ // call. To avoid a leak, there must be no error-case returns from here until
+ // the end of the method.
+ auto shaped_buffer = scoped_buffer->release();
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
+ const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
+ int device_ordinal, XRTTupleAllocation** allocation) {
+ auto allocator = backend->memory_allocator();
+
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
+ std::unique_ptr<xla::Literal>* literal) {
+ auto transfer_manager = backend->transfer_manager();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
+ stream.get(), ToShapedBuffer()));
+ return Status::OK();
+}
+
+void XRTTupleAllocation::DiscardAllocation(
+ const xla::ShapeIndex& buffer_index) {
+ buffers_.element(buffer_index)->DiscardAllocation();
+}
+
+const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }
+
+const xla::Shape& XRTTupleAllocation::on_device_shape() {
+ return on_device_shape_;
+}
+
+int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }
+
+const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
+ return buffers_.element({})->allocation();
+}
+
+/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation) {
+ string key_string = strings::StrCat(key);
+ TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
+ int64 key) {
+ string key_string = strings::StrCat(key);
+ return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
+}
+
+// Helper typedef to make ShapeTree ForEach helper lambda signatures more
+// readable. They need a type of const T& where in this case T is the
+// following pointer.
+typedef XRTBufferAllocation* XRTBufferAllocationPtr;
+
+/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
+ XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation, bool alias_parent_allocation) {
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* host_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* device_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));
+
+ *allocation =
+ new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
+ *host_sub_shape, *device_sub_shape);
+ if (alias_parent_allocation) {
+ // Copy the subtree of allocations from the parent allocation.
+ (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
+ // Increment the refcount on each aliased buffer.
+ (*allocation)
+ ->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ } else {
+ // Find the buffers in the parent allocation that match the subtree, and
+ // move the parent allocation's buffer over to the new allocation.
+ (*allocation)
+ ->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ // Extend the allocation's index to the parent's frame by adding
+ // subshape as a prefix.
+ xla::ShapeIndex parent_index = subshape;
+ for (int i = 0; i < index.size(); ++i) {
+ parent_index.push_back(index[i]);
+ }
+ *buffer = parent->buffers_.element(parent_index);
+ *parent->buffers_.mutable_element(parent_index) =
+ new XRTBufferAllocation(se::DeviceMemoryBase(),
+ parent->device_ordinal(),
+ parent->allocator_);
+ });
+ }
+
+ return Status::OK();
+}
+
+/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape) {
+ // Initialize both host and device shape to be the 'spine' of the new tuple
+ // shape, given by the shape of the tree of tuples.
+ *host_shape = elements.shape();
+ *device_shape = elements.shape();
+ // Now go over the leaves of the tree of tuples, and 'graft' the host/device
+ // shapes of the allocation at that leaf onto the expanded host/device shapes
+ // at the leaf position.
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ if (element.allocation == nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a null internal node at index ",
+ index.ToString());
+ }
+ if (device_ordinal != element.allocation->device_ordinal() ||
+ allocator != element.allocation->allocator_) {
+ return errors::InvalidArgument(
+ "MakeTuple elements must all be allocated on the same device "
+ "as the destination.");
+ }
+ *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
+ element.allocation->on_host_shape();
+ *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
+ element.allocation->on_device_shape();
+ } else {
+ if (element.allocation != nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a non-null internal node at index ",
+ index.ToString());
+ }
+ }
+ return Status::OK();
+ }));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::MakeTuple(
+ xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ xla::Shape host_shape;
+ xla::Shape device_shape;
+ TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
+ &host_shape, &device_shape));
+
+ // The aliasing is determined below based on whether or not all the inputs are
+ // released while being transferred. allocation_tmp is a local pointer that is
+ // copied to *allocation at the end only if the method succeeds.
+ auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
+ host_shape, device_shape);
+ core::ScopedUnref allocation_unref(allocation_tmp);
+ // First allocate device memory for the new tuple index tables, one at each
+ // internal node of the elements tree. Do this in a separate pass into a
+ // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
+ // an allocation fails. Make sure the shape has layout so that the code that
+ // writes index tables will be happy lower down.
+ xla::Shape spine_shape = elements.shape();
+ xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
+ auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
+ spine_shape, spine_shape, allocator, device_ordinal);
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (!elements.IsLeaf(index)) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(device_shape, index);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size,
+ /*retry_on_failure=*/false));
+ VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
+ << index.ToString();
+ // Move the new buffer into new_tuple_buffers, which takes ownership
+ // of it.
+ new_tuple_buffers->set_buffer(std::move(buffer), index);
+ }
+ return Status::OK();
+ }));
+ // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
+ // the newly-allocated index tables. Right now there's no owner for the new
+ // index tables, so next we will transfer ownership to the new allocation,
+ // taking care not to return early on any errors in the meantime.
+ xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
+ // Now fill in the remaining datastructures. After this ForEachElement
+ // completes:
+ // 1) Every leaf element of tuple_buffers will be the root buffer of
+ // an existing allocation, and every internal element of tuple_buffers
+ // will be a newly-allocated index table. tuple_buffers does not own any
+ // of these.
+ // 2) Every element of allocation_tmp->buffers_ will be a correctly
+ // constructed
+ // XRTBufferAllocation wrapping the necessary allocations. For buffers in
+ // existing allocations there will be a new reference owned by the new
+ // allocation, and for newly-allocated index tables there will be a
+ // single reference owned by the new allocation.
+ elements.ForEachElement([&](const xla::ShapeIndex& index,
+ const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
+ index);
+ tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
+ if (element.release_allocation_after_use) {
+ // Transfer the references from element's buffers to the new allocation
+ // rather than incrementing the refcount. The caller should have
+ // validated that release_allocation_after_use is false if
+ // element.allocation appears in more than one leaf.
+ element.allocation->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ *buffer = new XRTBufferAllocation(
+ se::DeviceMemoryBase(), element.allocation->device_ordinal(),
+ element.allocation->allocator_);
+ });
+ } else {
+ // Increment the refcount on each newly-aliased buffer.
+ element.allocation->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ }
+ } else {
+ // This is an internal node of the tuple tree so take ownership of the
+ // newly-created index table.
+ *allocation_tmp->buffers_.mutable_element(index) =
+ new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
+ allocator);
+ }
+ });
+ // Because the internal nodes of tuple_buffers are exactly the new index
+ // tables, WriteTupleIndexTables will write only the new index tables and not
+ // rewrite the index tables for the existing allocations.
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));
+
+ *allocation = allocation_tmp;
+ // Get another reference since allocation_tmp will be Unrefed automatically on
+ // exit.
+ (*allocation)->Ref();
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
+ *key = get_uid();
+ string key_string = strings::StrCat(*key);
+ return rm->Create(kTupleContainer, key_string, this);
+}
+
+bool XRTTupleAllocation::IsExclusiveOwner() {
+ for (const auto& buffer : buffers_) {
+ if (!buffer.second->RefCountIsOne()) return false;
+ }
+ return true;
+}
+
+void XRTTupleAllocation::InitializeFromShapedBuffer(
+ const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
+ for (auto& buffer : buffers_) {
+ // Make a reference-counted version of the allocated buffer.
+ buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
+ device_ordinal, allocator);
+ }
+}
+
+xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
+ xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
+ allocator_->platform(), device_ordinal_);
+ for (const auto& buffer : buffers_) {
+ shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
+ }
+ return shaped_buffer;
+}
+
+xla::ShapeTree<xla::MaybeOwningDeviceMemory>
+XRTTupleAllocation::ToDeviceMemoryTree(bool release) {
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
+ for (const auto& buffer : buffers_) {
+ if (!release) {
+ *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
+ } else {
+ *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
+ buffer.second->allocation(), device_ordinal_, allocator_);
+ DiscardAllocation(buffer.first);
+ }
+ }
+ return shaped_tree;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
new file mode 100644
index 0000000000..42705688dd
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -0,0 +1,208 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// TODO(misard) make this a Tensor if and when that makes sense.
+// A reference-counted wrapper around a buffer allocation. This maps an XLA
+// tuple index or a non-tuple XLA shape to a region of device memory. The device
+// memory buffer is freed when the reference count drops to zero.
+class XRTBufferAllocation : public core::RefCounted {
+ public:
+ XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator);
+ ~XRTBufferAllocation() override;
+
+ // The region of device memory being wrapped.
+ const se::DeviceMemoryBase& allocation();
+
+ // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
+ // when ownership of the underlying buffer has been transferred, e.g., to an
+ // output buffer when input and output buffers are aliased during
+ // execution. The call to DiscardAllocation prevents any device buffer being
+ // freed when the reference count drops to zero.
+ void DiscardAllocation();
+
+ private:
+ se::DeviceMemoryBase allocation_;
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+};
+
+// Entry in the resource manager corresponding to an allocation handle returned
+// to a client. The handle identifies an immutable tuple of data in device
+// memory. New handles can be created in three ways: by passing a literal in
+// which case device memory is allocated and the literal is transferred to that
+// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
+// aliasing a vector of existing handles to create a new tuple. The underlying
+// storage is reference-counted. When a handle is released, the reference count
+// of each storage buffer is decremented, and buffers with no outstanding
+// references are freed.
+class XRTTupleAllocation : public ResourceBase {
+ public:
+ ~XRTTupleAllocation() override;
+
+ // Allocates new device memory buffers sufficient to store literal, transfers
+ // literal to that memory, and returns a XRTTupleAllocation handle to the
+ // allocated buffers.
+ static Status CreateAndTransfer(const xla::Literal& literal,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+ // Wraps an existing ShapeBuffer in a new XRTTupleAllocation handle.
+ static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+ // Aliases a sub-shape of parent and returns a XRTTupleAllocation handle
+ // to the sub-shape. If alias_base_allocation is true, the buffers in the
+ // sub-shape will be shared between parent and the returned allocation,
+ // otherwise the overlapping buffers in parent will be replaced by
+ // nullptr.
+ static Status MakeSubBuffer(XRTTupleAllocation* parent,
+ const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation,
+ bool alias_parent_allocation);
+
+ // A structure describing a leaf of a tree of tuples to expand. Each leaf
+ // contains an allocation and indicates whether or not the allocation's handle
+ // should be freed after incorporating its buffers into the expanded tree.
+ struct ExpandedTupleInput {
+ XRTTupleAllocation* allocation;
+ bool release_allocation_after_use;
+ };
+
+ // Returns a handle to a new tuple where the subtree of the new tuple at an
+ // index corresponding to a leaf of 'elements' is constructed from the
+ // allocation (i.e., a tuple or array) pointed to by that leaf. If
+ // release_allocation_after_use is false at a leaf, the new tuple will alias
+ // the input allocation at that leaf, otherwise the input allocation will be
+ // released. Input allocations may be repeated (appear in more than one leaf)
+ // in which case the corresponding buffers in the output tuple will alias. If
+ // an input is repeated, release_input_handle must be false for every leaf
+ // where that input appears. The latter property is not validated by MakeTuple
+ // and must be enforced by the caller.
+ static Status MakeTuple(xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation);
+
+ // Retrieves the allocation interned under key from rm. The caller owns a
+ // reference to allocation after looking it up.
+ static Status Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation);
+
+ // Deletes the reference in the rm to an allocation interned under key.
+ static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);
+
+ // Adds the allocation to a ResourceMgr and returns the key that will be used
+ // to retrieve it. Transfers a reference on *this to rm.
+ Status Intern(ResourceMgr* rm, int64* key);
+
+ // Copies the allocation from device to host and returns it in literal.
+ Status ToLiteral(xla::Backend* backend, int device_ordinal,
+ std::unique_ptr<xla::Literal>* literal);
+
+ // True if none of the buffers in the allocation are aliased by any other live
+ // handle.
+ bool IsExclusiveOwner();
+
+ // The ordinal of the device holding this tuple.
+ int device_ordinal();
+
+ // Returns the shape of the tuple as seen by the host.
+ const xla::Shape& on_host_shape();
+
+ // Returns the shape of the tuple as stored on the device.
+ const xla::Shape& on_device_shape();
+
+ // Returns the buffer pointed to by the root of the tuple.
+ const se::DeviceMemoryBase& root_allocation();
+
+ // Stops managing the storage for the allocation at buffer_index, e.g.,
+ // because it has been aliased to the output buffer of a computation.
+ void DiscardAllocation(const xla::ShapeIndex& buffer_index);
+
+ // Returns the tree of allocations as a ShapedBuffer. This tree may not have
+ // the same shape as on_host_shape.
+ xla::ShapedBuffer ToShapedBuffer();
+
+ // Returns the device memory tree of this allocation. If 'release' is set, the
+ // ownership of the device memory is transferred to the result.
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(bool release);
+
+ string DebugString() override { return "XLA allocation handle"; }
+
+ private:
+ // Creates a new handle with (tuple) shape.
+ XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape);
+
+ // Inherits the allocations represented in buffer, which must have the same
+ // shape as buffers_.
+ void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator,
+ int device_ordinal);
+
+ // Takes a tree 'elements' where each leaf is an allocation, validates that
+ // they are all on device_ordinal managed by allocator, and returns in
+ // host_shape and device_shape the host/device shapes of the expanded tree,
+ // where at each leaf of elements the shape of the allocation at elements is
+ // grafted on.
+ static Status ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape);
+
+ // Location of the memory that is being managed.
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+
+ // The shape that the caller thinks the tuple has.
+ const xla::Shape on_host_shape_;
+ // The shape that the tuple has on device. Store this explicitly instead of
+ // using a shape stored in ShapeTree because ShapeTree discards the layout.
+ const xla::Shape on_device_shape_;
+ // The tree of reference-counted buffers, which uses on_device_shape_ as its
+ // shape.
+ xla::ShapeTree<XRTBufferAllocation*> buffers_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
index 6c281485b4..3630b41fc8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -23,7 +23,6 @@ py_test(
],
srcs_version = "PY2AND3",
tags = ["no_windows"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
index a0938b3e5f..fe630ef852 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -22,9 +22,11 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
"@gast_archive//:gast",
"@six_archive//:six",
+ # TODO(aqj) Revisit this dependency direction when pyct is more
+ # modularized
+ "//tensorflow/contrib/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index d0fd39fa30..3b28ed77f3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -739,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
context->input("bias_feature_id", &bias_feature_id_t));
int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -767,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer.
+ int size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+
Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -790,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -801,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -813,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(state.feature_column_group_id());
+ equality_split->set_feature_column(state->feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
}
+
+ void ComputeObliviousDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats per each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ // First feature ID in each partition should be the bias feature.
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
+ errors::InvalidArgument("Bias feature ID missing."));
+ GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_feature_id = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_feature_id = std::numeric_limits<int64>::max();
+ int64 last_feature_id = -1;
+ // Find the lowest feature id, this is going to be the first feature id to
+ // try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (feature_ids(start_index + 1, 0) < current_feature_id) {
+ current_feature_id = feature_ids(start_index + 1, 0);
+ }
+ }
+ // Indexes offsets for each of the partitions that can be used to access
+ // gradients of a partition for a current feature we consider. Start at one
+ // beacuse the zero index is for the bias.
+ std::vector<int> current_layer_offsets(num_elements, 1);
+ // The idea is to try every feature id in increasing order. In each
+ // iteration we calculate the gain of the layer using the current feature id
+ // as split value, and we also obtain the following feature id to try.
+ while (current_feature_id > last_feature_id) {
+ last_feature_id = current_feature_id;
+ int64 next_feature_id = -1;
+ // Left gradient stats per node.
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] = g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
+ next_feature_id = feature_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ best_feature_id = current_feature_id;
+ }
+ current_feature_id = next_feature_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* equality_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_categorical_id_binary_split();
+ equality_split->set_feature_column(state->feature_column_group_id());
+ equality_split->set_feature_id(best_feature_id);
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
+ }
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
+ }
};
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index efe29216c2..e6407174b1 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
@@ -46,6 +47,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -66,6 +68,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(EqualitySplitHandler, self).__init__(
@@ -85,6 +88,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
name="StatsAccumulator/{}".format(self._name))
self._sparse_int_column = sparse_int_column
+ self._weak_learner_type = weak_learner_type
def update_stats(self, stamp_token, example_partition_ids, gradients,
hessians, empty_gradients, empty_hessians, weights,
@@ -197,7 +201,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
tree_complexity_regularization=self._tree_complexity_regularization,
min_node_weight=self._min_node_weight,
bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
+ multiclass_strategy=self._multiclass_strategy,
+ weak_learner_type=self._weak_learner_type))
# There are no warm-up rounds needed in the equality column handler. So we
# always return ready.
are_splits_ready = constant_op.constant(True)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index ef253e7cec..d9f03c3840 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -169,6 +169,117 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
+ def testObliviousFeatureSplitGeneration(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 1 | 1 |
+ # i1 | (-0.5, 0.07) | 1 | 2 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [1, 1, 1, 2]
+ indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
+ values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+
+ with ops.control_dependencies([update_1, update_2]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([1, 2], partitions)
+
+ # For partition 1.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight1 = -0.9848484848484846
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain1 = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight1 = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain1 = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain1 = 0.46043165467625885
+
+ split_info = split_info_pb2.ObliviousSplitInfo()
+ split_info.ParseFromString(splits[0])
+ # Children of partition 1.
+ left_child = split_info.children[0].vector
+ right_child = split_info.children[1].vector
+ split_node = split_info.split_node.oblivious_categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+ self.assertEqual(1, split_node.feature_id)
+ self.assertAllClose([expected_left_weight1], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight1], right_child.value, 0.00001)
+
+ # For partition2.
+ expected_left_weight2 = 0
+ expected_left_gain2 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight2 = -3.4513274336283186
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_right_gain2 = 13.460176991150442
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain2 = 13.460176991150442
+
+ # Children of partition 2.
+ left_child = split_info.children[2].vector
+ right_child = split_info.children[3].vector
+ self.assertAllClose([expected_left_weight2], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight2], right_child.value, 0.00001)
+
+ self.assertAllClose(
+ expected_left_gain1 + expected_right_gain1 - expected_bias_gain1 +
+ expected_left_gain2 + expected_right_gain2 - expected_bias_gain2,
+ gains[0], 0.00001)
+
def testGenerateFeatureSplitCandidatesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
index 3ed6c5c04d..64921faf81 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
@@ -111,6 +111,18 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
node_id++;
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ const auto& split =
+ current_node.oblivious_categorical_id_binary_split();
+ oblivious_leaf_idx <<= 1;
+ const auto& features =
+ example.sparse_int_features[split.feature_column()];
+ if (features.find(split.feature_id()) == features.end()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break;
@@ -181,6 +193,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children.";
break;
@@ -220,6 +237,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
return {};
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
return {};
}
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 9b68a9de96..f1e12a028a 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -179,6 +179,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -224,6 +225,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index 500909bf2a..520b4f8b11 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -16,6 +16,7 @@ message TreeNode {
CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6;
ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
+ ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8;
}
TreeNodeMetadata node_metadata = 777;
}
@@ -116,6 +117,17 @@ message ObliviousDenseFloatBinarySplit {
// leaves.
}
+// Split rule for categorical features with a single feature Id in the oblivious
+// case.
+message ObliviousCategoricalIdBinarySplit {
+ // Categorical feature column and Id describing the rule feature == Id.
+ int32 feature_column = 1;
+ int64 feature_id = 2;
+ // We don't store children ids, because either the next node represents the
+ // whole next layer of the tree or starting with the next node we only have
+ // leaves.
+}
+
// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5e62bad672..74917f7cde 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertEqual(0, len(partitions))
self.assertEqual(0, len(gains))
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 97743ba255..b008c6e534 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -762,7 +762,8 @@ class GradientBoostedDecisionTreeModel(object):
hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
- loss_uses_sum_reduction=loss_uses_sum_reduction))
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type))
fc_name_idx += 1
# Create ensemble stats variables.
@@ -1063,6 +1064,12 @@ class GradientBoostedDecisionTreeModel(object):
# Grow the ensemble given the current candidates.
sizes = array_ops.unstack(split_sizes)
partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
+ # When using the oblivious decision tree as weak learner, it produces
+ # one gain and one split per handler and not number of partitions.
+ if self._learner_config.weak_learner_type == (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE):
+ sizes = len(training_state.handlers)
+
gains_list = list(array_ops.split(gains, sizes, axis=0))
split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
return training_ops.grow_tree_ensemble(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index f7867d882d..73e41bc457 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from google.protobuf import text_format
from tensorflow.contrib import layers
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, output.trees[0])
+ def testObliviousDecisionTreeAsWeakLearner(self):
+ with self.test_session():
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.learning_rate_tuner.fixed.learning_rate = 1
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 2
+ learner_config.constraints.min_node_weight = 0
+ learner_config.weak_learner_type = (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE
+ learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
+ features = {}
+ features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]],
+ dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN)
+ predictions = predictions_dict["predictions"]
+ labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+
+ train_op = gbdt_model.train(
+ loss=math_ops.reduce_mean(
+ _squared_loss(labels, weights, predictions)),
+ predictions_dict=predictions_dict,
+ labels=labels)
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect no splits to be chosen because the quantile
+ # buckets will not be ready.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 1)
+
+ # Second run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 2)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+ # Third run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 3)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -2.0
+ }
+ node_metadata {
+ gain: 0.25
+ original_oblivious_leaves {
+ vector {
+ value: -1.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -2.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index f6c928e2be..ebcabb4223 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -364,7 +364,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML)
+ add_definitions(-DINTEL_MKL_ML_ONLY)
endif()
endif (tensorflow_ENABLE_MKL_SUPPORT)
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b86a543fc3..34f594f741 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -293,6 +293,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 42adfd17f0..9d8e955245 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -720,6 +720,42 @@ class RestructuredDatasetTest(test.TestCase):
def test_assert_element_shape(self):
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
@@ -748,7 +784,60 @@ class RestructuredDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def test_assert_wrong_element_shape(self):
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape(self):
def create_dataset(_):
return (array_ops.ones(2, dtype=dtypes.float32),
@@ -756,11 +845,41 @@ class RestructuredDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.range(3).map(create_dataset)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
with self.assertRaises(ValueError):
dataset.apply(batching.assert_element_shape(wrong_shapes))
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
@@ -776,7 +895,7 @@ class RestructuredDatasetTest(test.TestCase):
self.assertEqual(unknown_shapes, dataset.output_shapes)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
iterator = (
dataset.apply(batching.assert_element_shape(wrong_shapes))
.make_initializable_iterator())
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9b1857de1a..9020a499c4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
# "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs]:
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index 446bf8d749..089717156c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -18,10 +18,13 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
+import numpy as np
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -62,7 +65,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
"Asserted next 2 transformations but encountered only 1."):
sess.run(get_next)
- def testDefaultOptimizations(self):
+ def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -75,7 +78,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyOptimizations(self):
+ def testOptimizationEmpty(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -88,7 +91,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testOptimization(self):
+ def testOptimizationFusion(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
@@ -101,11 +104,9 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testStatefulFunctionOptimization(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next([
- "MapAndBatch"
- ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+ def testOptimizationStatefulFunction(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10).apply(
optimization.optimize(["map_and_batch_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -113,6 +114,30 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.test_session() as sess:
sess.run(get_next)
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 64fe6dae24..fd00cdc5c6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -47,22 +47,50 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
@@ -90,7 +118,7 @@ class ReadBatchFeaturesTest(
with self.test_session() as sess:
sess.run(init_op)
- for file_batch, _, _, _, record_batch in self._next_expected_batch(
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
@@ -155,6 +183,25 @@ class ReadBatchFeaturesTest(
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
@@ -175,16 +222,20 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
- for _, tensor in outputs.items():
+ for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
- filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index e63bc4c720..08b9f03816 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -76,6 +76,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames,
num_epochs,
batch_size,
+ label_key=None,
reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
@@ -91,8 +92,10 @@ class ReadBatchFeaturesTestBase(test.TestCase):
features={
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
},
+ label_key=label_key,
reader=core_readers.TFRecordDataset,
num_epochs=self.num_epochs,
shuffle=shuffle,
@@ -101,7 +104,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
parser_num_threads=parser_num_threads,
drop_final_batch=drop_final_batch)
- def _record(self, f, r):
+ def _record(self, f, r, l):
example = example_pb2.Example(
features=feature_pb2.Features(
feature={
@@ -114,7 +117,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
"keywords":
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
}))
return example.SerializeToString()
@@ -139,23 +146,30 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
- writer.write(self._record(i, j))
+ writer.write(self._record(i, j, "fake-label"))
writer.close()
return filenames
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+ # outputs would be a tuple of (feature dict, label)
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
+ keywords_dense_shape_op, record_op, label_op
])
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
@@ -188,7 +202,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
- yield j, i
+ yield j, i, compat.as_bytes("fake-label")
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
@@ -200,6 +214,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
@@ -208,6 +223,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
for record in next_records:
f = record[0]
r = record[1]
+ label_batch.append(record[2])
file_batch.append(f)
record_batch.append(r)
keywords = self._get_keywords(f, r)
@@ -219,7 +235,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
if len(file_batch) == batch_size:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
]
file_batch = []
keywords_batch_indices = []
@@ -227,10 +243,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
if file_batch:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
]
def verify_records(self,
@@ -238,6 +255,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
batch_size,
file_index=None,
num_epochs=1,
+ label_key_provided=False,
interleave_cycle_length=1):
if file_index is not None:
file_indices = [file_index]
@@ -245,8 +263,12 @@ class ReadBatchFeaturesTestBase(test.TestCase):
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 53c22628c7..43067b4245 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -175,45 +174,5 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-class FeatureStatsDatasetTest(
- stats_dataset_test_base.StatsDatasetTestBase,
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testFeaturesStats(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- batch_size = 2
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5,
- drop_final_batch=False).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(total_records // batch_size + 1 if total_records %
- batch_size else total_records // batch_size):
- sess.run(next_element)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:features", total_records)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:feature-values", total_records)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:features", total_records * 3)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:feature-values",
- self._sum_keywords(1) * num_epochs + 2 * total_records)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 9f059942a6..9c2001c34f 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
```python
- shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+ shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])]
result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
- print(result.output_shapes) # ==> "((16, 256), <unknown>)"
+ print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))"
```
If dataset shapes and expected_shape, are fully defined, assert they match.
Otherwise, add assert op that will validate the shapes when tensors are
evaluated, and set shapes on tensors, respectively.
+ Note that unknown dimension in `expected_shapes` will be ignored.
+
Args:
expected_shapes: A nested structure of `tf.TensorShape` objects.
@@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes):
`tf.data.Dataset.apply`
"""
+ def _merge_output_shapes(original_shapes, expected_shapes):
+ flat_original_shapes = nest.flatten(original_shapes)
+ flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes)
+ flat_merged_output_shapes = [
+ original_shape.merge_with(new_shape)
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes)]
+ return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes)
+
def _check_shape(*elements):
flatten_tensors = nest.flatten(elements)
flatten_shapes = nest.flatten(expected_shapes)
checked_tensors = [
- with_shape(shape, tensor)
+ with_shape(shape, tensor) if shape else tensor # Ignore unknown shape
for shape, tensor in zip(flatten_shapes, flatten_tensors)
]
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
+ output_shapes = _merge_output_shapes(dataset.output_shapes,
+ expected_shapes)
return _RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
- output_shapes=expected_shapes,
+ output_shapes=output_shapes,
output_classes=dataset.output_classes)
return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 54a92ab185..38c0a09c33 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -235,6 +235,12 @@ def sample_from_datasets(datasets, weights=None, seed=None):
# to weights.
logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
def select_dataset_constant_logits(seed):
return array_ops.squeeze(
stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 29005859d7..7f09ba71dc 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -659,6 +659,7 @@ def make_batched_features_dataset(file_pattern,
batch_size,
features,
reader=core_readers.TFRecordDataset,
+ label_key=None,
reader_args=None,
num_epochs=None,
shuffle=True,
@@ -671,6 +672,9 @@ def make_batched_features_dataset(file_pattern,
drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
+ If label_key argument is provided, returns a `Dataset` of tuple
+ comprising of feature dictionaries and label.
+
Example:
```
@@ -721,6 +725,9 @@ def make_batched_features_dataset(file_pattern,
reader: A function or class that can be
called with a `filenames` tensor and (optional) `reader_args` and returns
a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key labels are stored in
+ `tf.Examples`. If provided, it must be one of the `features` key,
+ otherwise results in `ValueError`.
reader_args: Additional arguments to pass to the reader class.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. Defaults to `None`.
@@ -746,8 +753,11 @@ def make_batched_features_dataset(file_pattern,
`False`.
Returns:
- A dataset of `dict` elements. Each `dict` maps feature keys to
- `Tensor` or `SparseTensor` objects.
+ A dataset of `dict` elements, (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
"""
# Create dataset of all matching filenames
filenames = _get_file_names(file_pattern, False)
@@ -786,9 +796,13 @@ def make_batched_features_dataset(file_pattern,
parsing_ops.parse_example_dataset(
features, num_parallel_calls=parser_num_threads))
- # TODO(rachelim): Add an optional label_name argument for extracting the label
- # from the features dictionary, to comply with the type expected by the
- # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function.
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3b4e981402..8426228992 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -178,29 +178,6 @@ def latency_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def feature_stats(tag):
- """Records the features stats from `Example` records of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will be
- associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag)
-
- return _apply_fn
-
-
class _StatsDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index 02feeafb60..a87a5624c8 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -36,5 +36,6 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_config",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index ba92ea0b12..30e1992c01 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -12,26 +12,108 @@ models and training code with minimal changes to enable distributed training.
Moreover, we've designed the API in such a way that it works with both eager and
graph execution.
-Currently we support one type of strategy, called
-[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy).
-It does in-graph replication with synchronous training
+Currently we support several types of strategies:
+
+* [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy):
+This does in-graph replication with synchronous training
on many GPUs on one machine. Essentially, we create copies of all variables in
the model's layers on each device. We then use all-reduce to combine gradients
across the devices before applying them to the variables to keep them in sync.
-In the future, we intend to support other kinds of training configurations such
-as multi-node, synchronous,
-[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program),
-parameter servers and model parallelism.
+* [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy):
+This is a version of `MirroredStrategy` for multi-working training. It uses
+a collective op to do all-reduce. This supports between-graph communication and
+synchronization, and delegates the specifics of the all-reduce implementation to
+the runtime (as opposed to encoding it in the graph). This allows it to perform
+optimizations like batching and switch between plugins that support different
+hardware or algorithms. In the future, this strategy will implement
+fault-tolerance to allow training to continue when there is worker failure.
+
+* [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy):
+This strategy supports using parameter servers either for multi-GPU local
+training or asynchronous multi-machine training. When used to train locally,
+variables are not mirrored, instead they placed on the CPU and operations are
+replicated across all local GPUs. In a multi-machine setting, some are
+designated as workers and some as parameter servers. Each variable is placed on
+one parameter server. Computation operations are replicated across all GPUs of
+the workers.
+
+## Multi-GPU Training
+
+## Example with Keras API
+
+Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras).
+
+Take a very simple model consisting of a single layer:
+
+```python
+inputs = tf.keras.layers.Input(shape=(1,))
+predictions = tf.keras.layers.Dense(1)(inputs)
+model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
+```
-## Example
+Let's also define a simple input dataset for training this model. Note that currently we require using
+[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
+with `DistributionStrategy`.
+
+```python
+features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+train_dataset = tf.data.Dataset.zip((features, labels))
+```
-Let's demonstrate how to use this API with a simple example. We will use the
-[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
-approach, and show you how to scale your model to run on multiple GPUs on one
-machine using `MirroredStrategy`.
-Let's consider a very simple model function which tries to learn a simple
-function.
+To distribute this Keras model on multiple GPUs using `MirroredStrategy` we
+first instantiate a `MirroredStrategy` object.
+
+```python
+distribution = tf.contrib.distribute.MirroredStrategy()
+```
+
+We then compile the Keras model and pass the `MirroredStrategy` object in the
+`distribute` argument (apart from other usual arguments like `loss` and
+`optimizer`).
+
+```python
+model.compile(loss='mean_squared_error',
+ optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
+ distribute=strategy)
+```
+
+To train the model we call Keras `fit` API using the input dataset that we
+created earlier, same as how we would in a non-distributed case.
+
+```python
+model.fit(train_dataset, epochs=5, steps_per_epoch=10)
+```
+
+Similarly, we can also call `evaluate` and `predict` as before using appropriate
+datasets.
+
+```python
+model.evaluate(eval_dataset)
+model.predict(predict_dataset)
+```
+
+That's all you need to train your model with Keras on multiple GPUs with
+`MirroredStrategy`. It will take care of splitting up
+the input dataset, replicating layers and variables on each device, and
+combining and applying gradients.
+
+The model and input code does not have to change because we have changed the
+underlying components of TensorFlow (such as
+optimizer, batch norm and summaries) to become distribution-aware.
+That means those components know how to
+combine their state across devices. Further, saving and checkpointing works
+seamlessly, so you can save with one or no distribution strategy and resume with
+another.
+
+
+## Example with Estimator API
+
+You can also use Distribution Strategy API with [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Let's see a simple example of it's usage with `MirroredStrategy`.
+
+
+Consider a very simple model function which tries to learn a simple function.
```python
def model_fn(features, labels, mode):
@@ -53,17 +135,14 @@ def model_fn(features, labels, mode):
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
-Let's also define a simple input function to feed data for training this model.
-Note that we require using
-[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
-with `DistributionStrategy`.
+Again, let's define a simple input function to feed data for training this model.
```python
def input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
labels = tf.data.Dataset.from_tensors(1.).repeat(100)
- return dataset_ops.Dataset.zip((features, labels))
+ return tf.data.Dataset.zip((features, labels))
```
Now that we have a model function and input function defined, we can define the
@@ -80,20 +159,14 @@ distribution = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)
+classifier.evaluate(input_fn=input_fn)
```
That's it! This change will now configure estimator to run on all GPUs on your
-machine, with the `MirroredStrategy` approach. It will take care of distributing
-the input dataset, replicating layers and variables on each device, and
-combining and applying gradients.
+machine.
-The model and input functions do not have to change because we have changed the
-underlying components of TensorFlow (such as
-optimizer, batch norm and summaries) to become distribution-aware.
-That means those components know how to
-combine their state across devices. Further, saving and checkpointing works
-seamlessly, so you can save with one or no distribution strategy and resume with
-another.
+
+## Customization and Performance Tips
Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__).
There are few things you can customize in practice:
@@ -103,8 +176,6 @@ of GPUs (using param `num_gpus`), in case you don't want auto detection.
* You can specify various parameters for all reduce with the `cross_tower_ops`
param, such as the all reduce algorithm to use, and gradient repacking.
-## Performance Tips
-
We've tried to make it such that you get the best performance for your existing
model. We also recommend you follow the tips from
[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance).
@@ -113,15 +184,177 @@ and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_perform
in the input function gives a solid boost in performance. When using
`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size.
+## Multi-worker Training
+### Overview
+
+For multi-worker training, no code change is required to the `Estimator` code.
+You can run the same model code for all tasks in your cluster including
+parameter servers and the evaluator. But you need to use
+`tf.estimator.train_and_evaluator`, explicitly specify `num_gpus_per_workers`
+for your strategy object, and set "TF\_CONFIG" environment variables for each
+binary running in your cluster. We'll provide a Kubernetes template in the
+[tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets
+"TF\_CONFIG" for your training tasks.
+
+### TF\_CONFIG environment variable
+
+The "TF\_CONFIG" environment variables is a JSON string which specifies what
+tasks constitute a cluster, their addresses and each task's role in the cluster.
+One example of "TF\_CONFIG" is:
+
+```python
+TF_CONFIG='{
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"],
+ "ps": ["host4:port", "host5:port"]
+ },
+ "task": {"type": "worker", "index": 1}
+}'
+```
+
+This "TF\_CONFIG" specifies that there are three workers and two ps tasks in the
+cluster along with their hosts and ports. The "task" part specifies that the
+role of the current task in the cluster, worker 1. Valid roles in a cluster is
+"chief", "worker", "ps" and "evaluator". There should be no "ps" job for
+`CollectiveAllReduceStrategy` and `MirroredStrategy`. The "evaluator" job is
+optional and can have at most one task. It does single machine evaluation and if
+you don't want to do evaluation, you can pass in a dummy `input_fn` to the
+`tf.estimator.EvalSpec` of `tf.estimator.train_and_evaluate`.
+
+### Dataset
+
+The `input_fn` you provide to estimator code is for one worker. So remember to
+scale up your batch if you have multiple GPUs on each worker.
+
+The same `input_fn` will be used for all workers if you use
+`CollectiveAllReduceStrategy` and `ParameterServerStrategy`. Therefore it is
+important to shuffle your dataset in your `input_fn`.
+
+`MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you
+`input_fn`. As a result, each worker gets a fraction of your input data.
+
+### Performance Tips
+
+We have been actively working on multi-worker performance. Currently, prefer
+`CollectiveAllReduceStrategy` for synchronous multi-worker training.
+
+### Example
+
+Let's use the same example for multi-worker. We'll start a cluster with 3
+workers doing synchronous all-reduce training. In the following code snippet, we
+start multi-worker training using `tf.estimator.train_and_evaluate`:
+
+
+```python
+def model_main():
+ estimator = ...
+ distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=2)
+ config = tf.estimator.RunConfig(train_distribute=distribution)
+ train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+ eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+
+**Note**: You don't have to set "TF\_CONFIG" manually if you use our provided
+Kubernetes template.
+
+You'll then need 3 machines, find out their host addresses and one available
+port on each machine. Then set "TF\_CONFIG" in each binary and run the above
+model code.
+
+In your worker 0, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 0}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 1, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 1}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 2, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 2}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+Then you'll find your cluster has started training! You can inspect the logs of
+workers or start a tensorboard.
+
+### Standalone client mode
+
+We have a new way to run distributed training. You can bring up standard
+tensorflow servers in your cluster and run your model code anywhere such as on
+your laptop.
+
+In the above example, instead of calling `model_main`, you can call
+`tf.contrib.distribute.run_standard_tensorflow_server().join()`. This will bring
+up a cluster running standard tensorflow servers which wait for your request to
+start training.
+
+On your laptop, you can run
+
+```python
+estimator = ...
+distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=2)
+config = tf.estimator.RunConfig(
+ experimental_distribute=tf.contrib.distribute.DistributeConfig(
+ train_distribute=distribution,
+ remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]}))
+train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+Then you will see the training logs on your laptop. You can terminate the
+training by terminating your process on your laptop. You can also modify your
+code and run a new model against the same cluster.
+
+We've been optimizing the performance of standalone client mode. If you notice
+high latency between your laptop and your cluster, you can reduce that latency
+by running your model binary in the cluster.
+
## Caveats
+
This feature is in early stages and there are a lot of improvements forthcoming:
* Summaries are only computed in the first tower in `MirroredStrategy`.
-* Evaluation is not yet distributed.
* Eager support is in the works; performance can be more challenging with eager
execution.
-* As mentioned earlier, multi-node and other distributed strategies will be
-introduced in the future.
+* We currently support the following predefined Keras callbacks:
+`ModelCheckpointCallback`, `TensorBoardCallback`. We will soon be adding support for
+some of the other callbacks such as `EarlyStopping`, `ReduceLROnPlateau`, etc. If you
+create your own callback, you will not have access to all model properties and
+validation data.
* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch)
your input data, we will place one batch on each GPU in each step. So your
effective batch size will be `num_gpus * batch_size`. Therefore, consider
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index bf763215ba..350f81f60f 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python.parameter_server_strategy import Param
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server
from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import *
@@ -56,6 +57,7 @@ _allowed_symbols = [
'get_tower_context',
'has_distribution_strategy',
'require_tower_context',
+ 'run_standard_tensorflow_server',
'UpdateContext',
]
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 94deb2a432..c524d8b394 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -279,10 +279,11 @@ cuda_py_test(
":strategy_test_lib",
"//tensorflow/python:distribute",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
- "//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 2331444261..4fa8aa06cc 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -22,17 +22,16 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
+from tensorflow.python.platform import tf_logging as logging
-# TODO(yuefengz): shard the dataset.
# TODO(yuefengz): support in-graph replication.
-# TODO(yuefengz): it only works with a cluster without a chief node, maybe
-# support chief node?
class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Distribution strategy that uses collective ops for all-reduce.
@@ -51,42 +50,64 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Initializes the object.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
+ is 0 meaning CPU only.
"""
self._num_gpus_per_worker = num_gpus_per_worker
- self._initialize(None, None, None)
-
- def _initialize(self, cluster_spec, task_type, task_id):
- if cluster_spec:
- if task_type is None or task_id is None:
- raise ValueError("When `cluster_spec` is given, you must also specify "
- "`task_type` and `task_id`")
- if task_type not in ["chief", "worker"]:
- raise ValueError(
- "Unrecognized task_type: %r, valid task types are: \"chief\", "
- "\"worker\"." % task_type)
- self._cluster_spec = multi_worker_util.normalize_cluster_spec(
- cluster_spec)
- worker_device = "/job:%s/task:%d" % (task_type, task_id)
- num_workers = len(self._cluster_spec.as_dict().get("worker", [])) + len(
- self._cluster_spec.as_dict().get("chief", []))
- if not num_workers:
- raise ValueError("No `worker` or `chief` tasks can be found in "
- "`cluster_spec`.")
-
- self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
- task_id)
+ self._initialize_local_worker(num_gpus_per_worker)
+
+ def _initialize_local_worker(self, num_gpus_per_worker):
+ """Initializes the object for local training."""
+ self._is_chief = True
+ self._num_workers = 1
+
+ if num_gpus_per_worker:
+ local_devices = [
+ "/device:GPU:%d" % i for i in range(num_gpus_per_worker)
+ ]
else:
- self._cluster_spec = None
- self._is_chief = True
- worker_device = ""
- num_workers = 1
- self._num_workers = num_workers
+ local_devices = ["/device:CPU:0"]
+
+ self._collective_keys = cross_tower_utils.CollectiveKeys()
+ super(CollectiveAllReduceStrategy, self).__init__(
+ devices=local_devices,
+ cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
+ num_workers=1,
+ num_gpus_per_worker=num_gpus_per_worker,
+ collective_keys=self._collective_keys))
+
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info("CollectiveAllReduceStrategy with local_devices = %r",
+ local_devices)
- if self._num_gpus_per_worker:
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initializes the object for multi-worker training."""
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
+ if task_type not in ["chief", "worker"]:
+ raise ValueError(
+ "Unrecognized task_type: %r, valid task types are: \"chief\", "
+ "\"worker\"." % task_type)
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len(
+ cluster_spec.as_dict().get("chief", []))
+ if not self._num_workers:
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
+
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ if num_gpus_per_worker:
local_devices = [
"%s/device:GPU:%d" % (worker_device, i)
- for i in range(self._num_gpus_per_worker)
+ for i in range(num_gpus_per_worker)
]
else:
local_devices = [worker_device]
@@ -95,14 +116,23 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
super(CollectiveAllReduceStrategy, self).__init__(
devices=local_devices,
cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
- num_workers=num_workers,
- num_gpus_per_worker=self._num_gpus_per_worker,
+ num_workers=self._num_workers,
+ num_gpus_per_worker=num_gpus_per_worker,
collective_keys=self._collective_keys))
# Add a default device so that ops without specified devices will not end up
# on other workers.
- if cluster_spec:
- self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id)
+ self._default_device = "/job:%s/task:%d" % (task_type, task_id)
+
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker CollectiveAllReduceStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_workers = %r, local_devices = %r", cluster_spec.as_dict(),
+ task_type, task_id, self._num_workers, local_devices)
def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
@@ -166,6 +196,12 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
return mirrored_strategy._create_mirrored_variable(
devices, _real_mirrored_creator, *args, **kwargs)
+ def distribute_dataset(self, dataset_fn):
+ """Distributes the dataset to each local GPU."""
+ # TODO(yuefengz): shard the dataset.
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices, True)
+
def configure(self,
session_config=None,
cluster_spec=None,
@@ -183,11 +219,43 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
Raises:
ValueError: if `task_type` is not in the `cluster_spec`.
"""
- # TODO(yuefengz): we'll need to mutate the session_config to add
- # configurations for collective ops.
- del session_config
if not self._cluster_spec and cluster_spec:
- self._initialize(cluster_spec, task_type, task_id)
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
+ self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
+ task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ assert self._task_type
+ assert self._task_id is not None
+
+ # Collective group leader is needed for collective ops to coordinate
+ # workers.
+ if "chief" in self._cluster_spec.jobs:
+ session_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in self._cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ session_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ # The device filters prevent communication between workers.
+ del session_config.device_filters[:]
+ session_config.device_filters.append(
+ "/job:%s/task:%d" % (self._task_type, self._task_id))
+
+ # The scoped_allocator_optimization is to optimize graphs for collective
+ # ops.
+ rewrite_options = session_config.graph_options.rewrite_options
+ rewrite_options.scoped_allocator_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ del rewrite_options.scoped_allocator_opts.enable_op[:]
+ rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
@property
def between_graph(self):
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index e284969b1a..36e9761073 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -62,7 +62,10 @@ class CollectiveAllReduceStrategyTestBase(
num_gpus_per_worker=num_gpus)
if task_type and task_id is not None:
distribution.configure(
- cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
+ session_config=self._sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
CollectiveAllReduceStrategyTestBase.collective_key_base,
@@ -187,11 +190,6 @@ class DistributedCollectiveAllReduceStrategyTest(
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- def setUp(self):
- super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
- self._sess_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
-
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testMinimizeLossGraph(self, num_gpus):
@@ -221,8 +219,6 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
def setUp(self):
super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
self._run_options.experimental.collective_graph_key = 7
- self._sess_config.experimental.collective_group_leader = (
- '/job:chief/replica:0/task:0')
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 2a653b0f10..e08ba9c2a6 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -35,13 +35,13 @@ from tensorflow.python.training import device_util
def check_destinations(destinations):
- """Checks whether `destinations` is not None and not empty.
+ """Checks whether `destinations` is not empty.
Args:
destinations: a DistributedValues, Variable, string or a list of strings.
Returns:
- Boolean indicating whether `destinations` is not None and not empty.
+ Boolean which is True if `destinations` is not empty.
"""
# Calling bool() on a ResourceVariable is not allowed.
if isinstance(destinations, resource_variable_ops.ResourceVariable):
@@ -56,7 +56,7 @@ def validate_destinations(destinations):
value_lib.AggregatingVariable, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, a device string, a list of device "
- "strings or None")
+ "strings")
if not check_destinations(destinations):
raise ValueError("destinations can not be empty")
@@ -131,8 +131,7 @@ def _devices_match(left, right):
def _all_devices_match(value_destination_pairs):
- if not all([d is None or _devices_match(v, d)
- for v, d in value_destination_pairs]):
+ if not all([_devices_match(v, d) for v, d in value_destination_pairs]):
return False
if not all([_devices_match(v, value_destination_pairs[0][0])
for v, _ in value_destination_pairs[1:]]):
@@ -189,7 +188,7 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, aggregation, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations):
"""Reduce `per_device_value` to `destinations`.
It runs the reduction operation defined by `aggregation` and put the
@@ -210,8 +209,7 @@ class CrossTowerOps(object):
if not isinstance(per_device_value, value_lib.PerDevice):
per_device_value = _make_tensor_into_per_device(per_device_value)
- if destinations is not None:
- validate_destinations(destinations)
+ validate_destinations(destinations)
return self._reduce(aggregation, per_device_value, destinations)
def batch_reduce(self, aggregation, value_destination_pairs):
@@ -224,9 +222,7 @@ class CrossTowerOps(object):
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
- (or tensors with device set if there is one tower) and destinations. If
- a destination is None, then the destinations are set to match the
- devices of the input PerDevice object.
+ (or tensors with device set if there is one tower) and destinations.
Returns:
a list of Mirrored objects.
@@ -242,8 +238,7 @@ class CrossTowerOps(object):
value_destination_pairs)
for _, d in value_destination_pairs:
- if d is not None:
- validate_destinations(d)
+ validate_destinations(d)
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -573,7 +568,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
- if ((destinations is None or _devices_match(per_device_value, destinations))
+ if (_devices_match(per_device_value, destinations)
and not context.executing_eagerly()
and not contains_indexed_slices):
return self._batch_all_reduce(aggregation, [per_device_value])[0]
@@ -602,8 +597,10 @@ class AllReduceCrossTowerOps(CrossTowerOps):
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
- logging.warning("Efficient batch_reduce is not supported if "
- "destinations are different.")
+ logging.log_first_n(logging.WARN,
+ "Efficient batch_reduce is not supported if "
+ "destinations are different.",
+ 10)
return [
self._reduce(aggregation, t, destinations=v)
@@ -782,7 +779,7 @@ class CollectiveAllReduce(CrossTowerOps):
def __init__(self,
num_workers=1,
num_gpus_per_worker=0,
- all_reduce_merge_scope=1,
+ all_reduce_merge_scope=32,
collective_keys=None):
"""Initializes the object.
@@ -803,8 +800,15 @@ class CollectiveAllReduce(CrossTowerOps):
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, aggregation, per_device_value, destinations):
+ if cross_tower_utils.contains_indexed_slices(per_device_value):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
- if destinations is None or _devices_match(per_device_value, destinations):
+ if _devices_match(per_device_value, destinations):
return all_reduced
else:
index = {}
@@ -820,15 +824,33 @@ class CollectiveAllReduce(CrossTowerOps):
return value_lib.Mirrored(index)
def _batch_reduce(self, aggregation, value_destination_pairs):
- return [
- self._reduce(aggregation, t, destinations=v)
- for t, v in value_destination_pairs
- ]
+ if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
+ all_devices_match = _all_devices_match(value_destination_pairs)
+ if all_devices_match:
+ return self._batch_all_reduce(aggregation,
+ [v[0] for v in value_destination_pairs])
+ else:
+ if not all_devices_match:
+ logging.log_first_n(
+ logging.WARN, "Efficient batch_reduce is not supported if "
+ "destinations are different.", 10)
+
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _batch_all_reduce(self, aggregation, per_device_values):
"""All-reduce across all workers in a batch."""
if context.executing_eagerly():
- raise ValueError("Eager mode with collective ops is not supported yet.")
+ raise ValueError(
+ "Eager execution with collective ops is not supported yet.")
logging.log_first_n(
logging.INFO, "Collective All-reduce invoked with batches size = %d, "
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 2ad91d56e9..490371477a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -135,7 +135,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_mirrored, destination_different, destination_str,
destination_list
]
@@ -146,24 +146,24 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device))
+ _fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device))
+ _fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM, per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices), destinations or per_device))
+ _fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices), destinations or per_device))
+ _fake_mirrored(mean_2 * len(devices), destinations))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -171,25 +171,22 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ _fake_mirrored(mean * len(devices), d1),
+ _fake_mirrored(mean_2 * len(devices), d2)
])
# test broadcast()
for destinations in all_destinations:
- if destinations is None:
- continue
- else:
- self._assert_values_equal(
- cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
- _fake_mirrored(1., destinations))
+ self._assert_values_equal(
+ cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
+ _fake_mirrored(1., destinations))
class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
@@ -494,7 +491,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- destination_different, None, destination_mirrored, destination_str,
+ destination_different, destination_mirrored, destination_str,
destination_list
]
@@ -505,27 +502,27 @@ class MultiWorkerCollectiveAllReduceTest(
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device), sess)
+ _fake_mirrored(mean, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device), sess)
+ _fake_mirrored(mean_2, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean * len(devices) * num_workers, destinations),
+ sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
+ sess)
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -534,18 +531,16 @@ class MultiWorkerCollectiveAllReduceTest(
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
], sess)
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices) * num_workers, d1 or
- per_device),
- _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
- per_device_2)
+ _fake_mirrored(mean * len(devices) * num_workers, d1),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
], sess)
return True
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py
index 1f24f62947..f07ec8234d 100644
--- a/tensorflow/contrib/distribute/python/input_ops.py
+++ b/tensorflow/contrib/distribute/python/input_ops.py
@@ -47,11 +47,8 @@ def auto_shard_dataset(dataset, num_shards, index):
Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the
- files.
-
- Raises:
- NotImplementedError: If we cannot automatically determine a good way to
- shard the input dataset.
+ files. The input dataset will be returned if we cannot automatically
+ determine a good way to shard the input dataset.
"""
# TODO(priyag): Clone datasets instead of updating in place, similar to the
@@ -127,8 +124,10 @@ def auto_shard_dataset(dataset, num_shards, index):
tf_logging.warn(
"Could not find a standard reader in the input pipeline"
"(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)."
- "Falling back to sharding the dataset anyway. Please verify"
- "correctness of auto-sharding for your input.")
+ "So auto-sharding is not done. Please verify correctness of "
+ "auto-sharding for your input.")
+ # TODO(yuefengz): maybe still shard it?
+ return dataset
# TODO(priyag): What do we want to do if the number of filenames is
# uneven in the number of shards? By default, this will just return as
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e87b48ba41..d1235b7afb 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -65,7 +65,7 @@ class _RequestedStop(Exception):
pass
-# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# _call_for_each_tower and _reduce_non_distributed_value are not members of
# MirroredStrategy so that they are generally not allowed to use anything
# specific to MirroredStrategy and thus can be shared with other distribution
# strategies.
@@ -197,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and equal to 0.
if value == 0:
return 0
- # If the aggregation type is MEAN, then this essentially means that the same
- # value should be on all destinations.
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return distribution.broadcast(value, destinations)
+ # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
+ # essentially means that the same value should be on all destinations.
+ if aggregation in (
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
+ return value
cross_tower_ops_lib.validate_destinations(destinations)
# We do not support an aggregation type of SUM if the value is the same across
@@ -208,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value cannot be reduced with the "
- "given aggregation.")
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given aggregation %s." % (value, aggregation))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
@@ -254,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
- if aggregation not in [
+ if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ ):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -337,6 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus: number of GPUs. For local training, either specify `devices` or
`num_gpus`. In distributed training, this must be specified as number of
GPUs on each worker.
+ num_gpus_per_worker: number of GPUs per worker. This is the same as
+ `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
+ specified.
cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
@@ -346,6 +352,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def __init__(self,
devices=None,
num_gpus=None,
+ num_gpus_per_worker=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
@@ -353,9 +360,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
# Rememeber num GPUs which might be needed by `configure` method.
- self._num_gpus = num_gpus
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is not None:
+ self._num_gpus = num_gpus
+ else:
+ self._num_gpus = num_gpus_per_worker
- self._initialize_local(num_gpus, devices)
+ self._initialize_local(self._num_gpus, devices)
def _initialize_local(self, num_gpus, devices):
"""Initializes the object for local training."""
@@ -564,8 +577,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cross_tower_ops is None:
if self._cluster_spec:
- self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
- self._workers, self._num_gpus)
+ # It currently cannot detect the toplogy of remote workers. So we
+ # hard-code the multi-worker all-reduce algorithm for now.
+ if len(self._workers) == 1:
+ # The default is "nccl".
+ self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps()
+ else:
+ # The default is hierarchical reduce and broadcast.
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
else:
self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
self._devices, session_config=session_config)
@@ -584,10 +604,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ value = value.get(self._devices[0])
+ if isinstance(value, (int, float)):
+ return value
+ return self.broadcast(value, destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._devices[0]), d)
+ for v, d in value_destination_pairs]
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index a12ff662db..c6894e9013 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
@@ -128,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes
+ def testReduceOnlyFirstTowerUpdates(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ def run_fn(device_id):
+ return constant_op.constant(3 + 5 * device_id)
+
+ dist = self._get_distribution_strategy()
+ with dist.scope():
+ result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER,
+ result,
+ destinations="/device:CPU:0")
+ unwrapped = dist.unwrap(reduced)
+ self.assertEqual(1, len(unwrapped))
+ self.assertEqual(3, self.evaluate(unwrapped[0]))
+
@test_util.run_in_graph_and_eager_modes()
def testReduceToMultipleDestinations(self):
if not GPU_TEST:
@@ -384,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
v3.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testOnlyFirstTowerUpdatesVariables(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def create_fn():
+ aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ v0 = variable_scope.variable(
+ 2.0,
+ name="on_read",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ v1 = variable_scope.variable(
+ 3.0,
+ name="on_write",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=aggregation)
+ return v0, v1
+
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False)
+ self.evaluate(v0.initializer)
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0, self.evaluate(dist.read_var(v0)))
+ self.evaluate(v1.initializer)
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using the assign_add member function.
+ def update_member_fn(device_id):
+ update0 = v0.assign_add(5.0 * (device_id + 1))
+ update1 = v1.assign_add(7.0 * (device_id + 1))
+ return update0, update1
+
+ update0a, update1a = dist.call_for_each_tower(
+ update_member_fn, dist.worker_device_index, run_concurrently=False)
+
+ # Update "sync on read" variable.
+ self.evaluate(dist.group(update0a))
+ self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
+ # Writes are not synchronized for "sync on read" variables,
+ # so device[1] can end up with a different value.
+ self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
+ # Always reads from device 0.
+ self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1a))
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
+ # Writes are synchronized for v1, only the argument to assign_add on
+ # device[0] is used.
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using state_ops.assign_add global function.
+ def update_state_ops_fn(device_id):
+ update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1))
+ update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1))
+ return update0, update1
+
+ update0b, update1b = dist.call_for_each_tower(
+ update_state_ops_fn, dist.worker_device_index, run_concurrently=False)
+ self.evaluate(dist.group(update0b))
+
+ # Update "sync on read" variable.
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1b))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testNoneSynchronizationWithGetVariable(self):
self._skip_eager_if_gpus_less_than(1)
devices = ["/device:CPU:0", "/device:GPU:0"]
@@ -804,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase):
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
- ValueError, "A non-DistributedValues value cannot be reduced with "
- "the given aggregation."):
+ ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
+ "with the given aggregation VariableAggregation.SUM."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -1287,7 +1385,8 @@ class MultiWorkerMirroredStrategyTestWithChief(
cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
def testMinimizeLossGraph(self):
- strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy = mirrored_strategy.MirroredStrategy(
+ num_gpus_per_worker=context.num_gpus())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 68561b5bbf..23b220f64b 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -67,6 +67,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device)
def _broadcast(self, tensor, destinations):
+ del destinations
return tensor
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -127,6 +128,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
def _reduce(self, aggregation, value, destinations):
+ del destinations
if not isinstance(value, values.MapOutput):
return value
l = value.get()
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 361c8be590..88d7768b14 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -82,19 +83,12 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
create conflicts of device assignment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type=None,
- task_id=None):
+ def __init__(self, num_gpus_per_worker=0):
"""Initializes this strategy.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type.
- task_id: the current task id.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
+ is 0 meaning CPU only.
Raises:
ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
@@ -102,24 +96,16 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
"""
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
- if cluster_spec:
- cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- if task_type is None or task_id is None:
- raise ValueError("When `cluster_spec` is given, must also specify "
- "`task_type` and `task_id`.")
- self._cluster_spec = cluster_spec
+ self._initialize_local(num_gpus_per_worker)
# We typically don't need to do all-reduce in this strategy.
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
reduce_to_device=_LOCAL_CPU))
- self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type,
- task_id)
-
- def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type,
- task_id):
- """Initialize internal devices.
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initialize devices for multiple workers.
It creates variable devices and compute devices. Variables and operations
will be assigned to them respectively. We have one compute device per tower.
@@ -137,85 +123,103 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
Raises:
ValueError: if the cluster_spec doesn't have ps jobs.
"""
- self._task_type = task_type or "worker"
- self._task_id = task_id or 0
- self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
+ assert cluster_spec
+ if not task_type or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- # TODO(yuefengz): maybe clearer to split it into two classes, one for
- # the distribuetd case and one for the local case, once we have the factory
- # class/method.
+ self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
# Define compute devices which is a list of device strings and one for each
# tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
# place operations on CPU.
- if cluster_spec is None:
- # Local mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = list(
- map("/device:GPU:{}".format, range(num_gpus_per_worker)))
- else:
- self._compute_devices = [_LOCAL_CPU]
+ if num_gpus_per_worker > 0:
+ self._compute_devices = [
+ "%s/device:GPU:%d" % (self._worker_device, i)
+ for i in range(num_gpus_per_worker)
+ ]
else:
- # Distributed mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = [
- "%s/device:GPU:%d" % (self._worker_device, i)
- for i in range(num_gpus_per_worker)
- ]
- else:
- self._compute_devices = [self._worker_device]
+ self._compute_devices = [self._worker_device]
self._compute_devices = list(
map(device_util.resolve, self._compute_devices))
self._canonical_compute_device_set = set(self._compute_devices)
- # Define variable device which is a device string in the local case and a
- # device function in the distributed case. It is used to open a device scope
- # where varibles are defined.
+ # In distributed mode, place variables on ps jobs in a round-robin fashion.
+ # Note that devices returned from `replica_device_setter` are not
+ # canonical and therefore we don't canonicalize all variable devices to
+ # make them consistent.
+ # TODO(yuefengz): support passing a strategy object to control variable
+ # assignment.
+ # TODO(yuefengz): merge the logic of replica_device_setter into this
+ # class.
+ num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
+ if num_ps_replicas == 0:
+ raise ValueError("The cluster spec needs to have `ps` jobs.")
+ self._variable_device = device_setter.replica_device_setter(
+ ps_tasks=num_ps_replicas,
+ worker_device=self._worker_device,
+ merge_devices=True,
+ cluster=cluster_spec)
+
# The `_parameter_devices` is needed for the `parameter_devices` property
- # and is a list of all variable devices.
- if cluster_spec is None:
- # Local mode. If there is only one GPU, put everything on that GPU.
- # Otherwise, place variables on CPU.
- if num_gpus_per_worker == 1:
- assert len(list(self._compute_devices)) == 1
- self._variable_device = _LOCAL_GPU_0
- self._parameter_devices = [_LOCAL_GPU_0]
- else:
- self._variable_device = _LOCAL_CPU
- self._parameter_devices = [_LOCAL_CPU]
+ # and is a list of all variable devices. Here parameter devices are all
+ # tasks of the "ps" job.
+ self._parameter_devices = map("/job:ps/task:{}".format,
+ range(num_ps_replicas))
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ self._default_device = self._worker_device
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
+ self._cluster_spec = cluster_spec
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker ParameterServerStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
+ "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
+ num_ps_replicas, self._is_chief, self._compute_devices,
+ self._variable_device)
+
+ def _initialize_local(self, num_gpus_per_worker):
+ """Initialize internal devices for local training."""
+ # Define compute devices which is a list of device strings and one for each
+ # tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
+ # place operations on CPU.
+ if num_gpus_per_worker > 0:
+ self._compute_devices = list(
+ map("/device:GPU:{}".format, range(num_gpus_per_worker)))
else:
- # Distributed mode. Place variables on ps jobs in a round-robin fashion.
- # Note that devices returned from `replica_device_setter` are not
- # canonical and therefore we don't canonicalize all variable devices to
- # make them consistent.
- # TODO(yuefengz): support passing a strategy object to control variable
- # assignment.
- # TODO(yuefengz): merge the logic of replica_device_setter into this
- # class.
- num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
- if num_ps_replicas == 0:
- raise ValueError("The cluster spec needs to have `ps` jobs.")
- self._variable_device = device_setter.replica_device_setter(
- ps_tasks=num_ps_replicas,
- worker_device=self._worker_device,
- merge_devices=True,
- cluster=cluster_spec)
-
- # Parameter devices are all tasks of the "ps" job.
- self._parameter_devices = map("/job:ps/task:{}".format,
- range(num_ps_replicas))
-
- # Define the default device in cross-tower mode. In the distributed case, we
- # set the default device to the corresponding worker to prevent these ops
- # from being placed on other workers.
- if cluster_spec is None:
- self._default_device = None
+ self._compute_devices = [_LOCAL_CPU]
+
+ self._compute_devices = list(
+ map(device_util.resolve, self._compute_devices))
+ self._canonical_compute_device_set = set(self._compute_devices)
+
+ # If there is only one GPU, put everything on that GPU. Otherwise, place
+ # variables on CPU.
+ if num_gpus_per_worker == 1:
+ assert len(list(self._compute_devices)) == 1
+ self._variable_device = _LOCAL_GPU_0
+ self._parameter_devices = [_LOCAL_GPU_0]
else:
- self._default_device = self._worker_device
+ self._variable_device = _LOCAL_CPU
+ self._parameter_devices = [_LOCAL_CPU]
- self._is_chief = cluster_spec is None or multi_worker_util.is_chief(
- cluster_spec, task_type, task_id)
+ self._is_chief = True
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info(
+ "ParameterServerStrategy with compute_devices = %r, "
+ "variable_device = %r", self._compute_devices, self._variable_device)
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
@@ -235,7 +239,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if aggregation not in (
vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
- vs.VariableAggregation.MEAN
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER
):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -302,10 +307,15 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self.broadcast(value.get(self._compute_devices[0]), destinations)
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._compute_devices[0]), d)
+ for v, d in value_destination_pairs]
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_tower_ops.batch_reduce(aggregation,
@@ -385,18 +395,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
not.
"""
- del session_config
-
- # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in
- # the constructor.
if not self._cluster_spec and cluster_spec:
- self._cluster_spec = multi_worker_util.normalize_cluster_spec(
- cluster_spec)
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
if task_type is None or task_id is None:
raise ValueError("When `cluster_spec` is given, must also specify "
"`task_type` and `task_id`.")
- self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec,
- task_type, task_id)
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec)
+ self._task_type = task_type
+ self._task_id = task_id
+ self._initialize_multi_worker(self._num_gpus_per_worker,
+ self._cluster_spec, task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ assert self._cluster_spec
+ assert self._task_type
+ assert self._task_id is not None
+
+ # The device filters prevent communication between workers.
+ if self._task_type not in ["chief", "worker"]:
+ return
+ del session_config.device_filters[:]
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 0e2bfcec5f..12789e0bc9 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import threading
from absl.testing import parameterized
@@ -25,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
@@ -56,23 +58,30 @@ class ParameterServerStrategyTestBase(
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
+ self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
super(ParameterServerStrategyTestBase, self).setUp()
def _get_test_objects(self, task_type, task_id, num_gpus):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=num_gpus)
if not task_type:
- return distribution, ''
+ return distribution, '', self._sess_config
+ sess_config = copy.deepcopy(self._sess_config)
distribution.configure(
- cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
- return distribution, 'grpc://' + self._cluster_spec[WORKER][task_id]
+ session_config=sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id],
+ sess_config)
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
- d, _ = self._get_test_objects(task_type, task_id, num_gpus)
+ d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self.test_session(target=self._default_target) as sess, \
+ self.test_session(target=self._default_target,
+ config=sess_config) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -174,7 +183,8 @@ class ParameterServerStrategyTestBase(
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self.test_session(target=self._default_target) as sess, \
+ self.test_session(target=self._default_target,
+ config=self._sess_config) as sess, \
d.scope():
def model_fn():
@@ -268,7 +278,8 @@ class ParameterServerStrategyTestBase(
self.assertEqual(f_val, 46.0)
def _test_simple_increment(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
num_workers = len(d._cluster_spec.as_dict().get(WORKER))
if 'chief' in d._cluster_spec.as_dict():
@@ -276,7 +287,8 @@ class ParameterServerStrategyTestBase(
else:
num_workers = 1
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
def model_fn():
@@ -286,18 +298,22 @@ class ParameterServerStrategyTestBase(
y = variable_scope.get_variable(
'y', initializer=20.0,
aggregation=variable_scope.VariableAggregation.SUM)
+ z = variable_scope.get_variable(
+ 'z', initializer=30.0,
+ aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)
# We explicitly make a constant tensor here to avoid complaints about
# summing non-distributed values.
one = constant_op.constant(1.0)
x_add = x.assign_add(one, use_locking=True)
y_add = y.assign_add(one, use_locking=True)
+ z_add = z.assign_add(one, use_locking=True)
- train_op = control_flow_ops.group([x_add, y_add])
- return x, y, train_op
+ train_op = control_flow_ops.group(x_add, y_add, z_add)
+ return x, y, z, train_op
- x, y, train_op = d.call_for_each_tower(model_fn)
- train_op = d.group(d.unwrap(train_op))
+ x, y, z, train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(train_op)
if context.num_gpus() < d._num_gpus_per_worker:
return True
@@ -323,21 +339,25 @@ class ParameterServerStrategyTestBase(
self._finish_condition.notify_all()
self._finish_condition.release()
- x_val, y_val = sess.run([x, y])
+ x_val, y_val, z_val = sess.run([x, y, z])
self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers)
self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers)
+ self.assertEqual(z_val, 30.0 + 1.0 * num_workers)
return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
- y_val == 20.0 + 1.0 * num_workers * d.num_towers)
+ y_val == 20.0 + 1.0 * num_workers * d.num_towers and
+ z_val == 30.0 + 1.0 * num_workers)
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
assert hasattr(d, '_cluster_spec') and d._cluster_spec
num_workers = len(d._cluster_spec.as_dict().get(WORKER))
if CHIEF in d._cluster_spec.as_dict():
num_workers += 1
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 6ee26e19ac..5d498fb629 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -190,7 +190,8 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out,
+ "/device:CPU:0")
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 6202a0750a..32d7444e42 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -73,70 +73,98 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
num_cores: Number of cores to use on the TPU. If None specified, then
auto-detect the cores and topology of the TPU system.
"""
- # TODO(isaprykin): Generalize the defaults. They are currently tailored for
- # the unit test.
+ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the
+ # master node fetched from the cluster resolver.
super(TPUStrategy, self).__init__('/device:CPU:0')
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ # TODO(sourabhbajaj): Change this from num_cores to metadata_override
self._num_cores_override = num_cores
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
- # TODO(frankchn): This should not be hardcoded here for pod purposes.
- self._host = self.tpu_host_cpu_device(0)
+ def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
+ iterations):
+ """Create an enqueue op for a single host identified using host_id.
- def distribute_dataset(self, dataset_fn):
- # TODO(priyag): Perhaps distribute across cores here.
- return self._call_dataset_fn(dataset_fn)
-
- # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
- # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
- # a mechanism to infer the outputs of `fn`. Pending b/110550782.
- def _run_steps_on_dataset(self, fn, iterator, iterations,
- initial_loop_values=None):
+ The while_loop op returned will run `iterations` times and in each run
+ enqueue batches for each shard.
- shapes = nest.flatten(iterator.output_shapes)
- if any([not s.is_fully_defined() for s in shapes]):
- raise ValueError(
- 'TPU currently requires fully defined shapes. Either use '
- 'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
- types = nest.flatten(iterator.output_types)
+ Args:
+ host_id: integer, id of the host to run the enqueue ops on.
+ iterator: `tf.data` iterator to read the input data.
+ input_shapes: shape of inputs to be enqueue on the queue. This is same as
+ the value of `nest.flatten(iterator.output_shapes)`.
+ iterations: integer, number of iterations to be run; determines the
+ number of batches to be enqueued.
+
+ Returns:
+ while_loop_op running `iterations` times; in each run we enqueue a batch
+ on the infeed queue from the host with id `host_id` for each device shard.
+ """
+ host = self.get_host_cpu_device(host_id)
- def enqueue_ops_fn():
+ def _infeed_enqueue_ops_fn():
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
- # TODO(sourabhbajaj): Add support for TPU pods
- with ops.device(self._host):
- for _ in range(self.num_towers):
+ enqueue_ops = []
+
+ with ops.device(host):
+ for _ in range(self.num_towers_per_host):
# Use control dependencies to ensure a deterministic ordering.
with ops.control_dependencies(control_deps):
inputs = nest.flatten(iterator.get_next())
control_deps.extend(inputs)
sharded_inputs.append(inputs)
- enqueue_ops = []
for core_id, shard_input in enumerate(sharded_inputs):
enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
- inputs=shard_input, shapes=shapes, device_ordinal=core_id))
+ inputs=shard_input,
+ shapes=input_shapes,
+ device_ordinal=core_id))
return enqueue_ops
def enqueue_ops_loop_body(i):
- with ops.control_dependencies(enqueue_ops_fn()):
+ """Callable for the loop body of the while_loop instantiated below."""
+ with ops.control_dependencies(_infeed_enqueue_ops_fn()):
return i + 1
- with ops.device(self._host):
- enqueue_ops = control_flow_ops.while_loop(
+ with ops.device(host):
+ enqueue_op_per_host = control_flow_ops.while_loop(
lambda i: i < iterations,
enqueue_ops_loop_body,
[constant_op.constant(0)],
parallel_iterations=1)
+ return enqueue_op_per_host
+
+ def distribute_dataset(self, dataset_fn):
+ # TODO(priyag): Perhaps distribute across cores here.
+ return self._call_dataset_fn(dataset_fn)
+
+ # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
+ # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
+ # a mechanism to infer the outputs of `fn`. Pending b/110550782.
+ def _run_steps_on_dataset(self, fn, iterator, iterations,
+ initial_loop_values=None):
+
+ shapes = nest.flatten(iterator.output_shapes)
+ if any([not s.is_fully_defined() for s in shapes]):
+ raise ValueError(
+ 'TPU currently requires fully defined shapes. Either use '
+ 'set_shape() on the input tensors or use '
+ 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ types = nest.flatten(iterator.output_types)
+
+ enqueue_ops = [
+ self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations)
+ for host_id in range(self.num_hosts)]
+
def dequeue_fn():
dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
return nest.pack_sequence_as(iterator.output_shapes, dequeued)
@@ -147,6 +175,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
def run_fn(*args, **kwargs):
+ """Single step on the TPU device."""
del args, kwargs
fn_inputs = dequeue_fn()
if not isinstance(fn_inputs, tuple):
@@ -238,6 +267,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ 'Currently only support sum & mean in TPUStrategy.')
return tpu_ops.cross_replica_sum(value)
cf_context = cf_context.outer_context
@@ -247,10 +279,12 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
- self._host)
+ self.get_host_cpu_device(0))
else:
raise ValueError('Multiple devices are not supported for TPUStrategy')
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return value[0]
output = math_ops.add_n(value)
if aggregation == vs.VariableAggregation.MEAN:
return output * (1. / len(value))
@@ -265,8 +299,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def num_towers(self):
return self._num_cores_override or self._tpu_metadata.num_cores
- def tpu_host_cpu_device(self, host_id):
+ @property
+ def num_hosts(self):
+ return self._tpu_metadata.num_hosts
+
+ @property
+ def num_towers_per_host(self):
+ return self._tpu_metadata.num_of_cores_per_host
+
+ def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
- return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)
-
+ return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 3ccaa2690e..fafa6384a1 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -340,10 +340,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(self, index, primary_var, aggregation):
- # Use a weakref to make it easy to map from the contained values
- # to the container without introducing a reference cycle.
- for v in six.itervalues(index):
- v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -523,6 +519,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self._aggregation
def _get_cross_tower(self):
+ if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self._primary_var
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 3602f4d128..15a85a28f5 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -521,6 +521,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
return worker_device_map, devices
def testDataDistributionOneDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -528,6 +529,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testDataDistributionTwoDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
worker_device_map, devices = self._cpu_and_one_gpu_devices()
@@ -537,6 +539,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 2, 1, 3], [4, 6, 5, 7]])
def testTupleDataset(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
@@ -553,6 +556,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
expected_values)
def testInitializableIterator(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -570,6 +574,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testValueErrorForIterator(self):
+ self.skipTest("Temporarily disabled.")
# Incompatiable arguments.
with self.assertRaises(ValueError):
values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index a8d0d493ab..97c53ae2b9 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -445,7 +445,7 @@ cuda_py_test(
cuda_py_test(
name = "sinh_arcsinh_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/sinh_arcsinh_test.py"],
additional_deps = [
":distributions_py",
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index ee25d25b52..d60ee18586 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -147,11 +147,12 @@
" # random jittering\n",
" \n",
" # resizing to 286 x 286 x 3\n",
- " # method = 2 indicates using \"ResizeMethod.NEAREST_NEIGHBOR\"\n",
" input_image = tf.image.resize_images(input_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" real_image = tf.image.resize_images(real_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" \n",
" # randomly cropping to 256 x 256 x 3\n",
" stacked_image = tf.stack([input_image, real_image], axis=0)\n",
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index 3f70f573b1..9d090e8429 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -195,12 +195,12 @@ class ResNet50(tf.keras.Model):
def __init__(self,
data_format,
- name=None,
+ name='',
trainable=True,
include_top=True,
pooling=None,
classes=1000):
- super(ResNet50, self).__init__(name='')
+ super(ResNet50, self).__init__(name=name)
valid_channel_values = ('channels_first', 'channels_last')
if data_format not in valid_channel_values:
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index 505c94e971..513feb03b6 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -37,13 +37,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
@@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 4e6eea8884..bdf8aeb2b8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None):
return predictions, labels
-def _all_equal(tensor0, tensor1):
- with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
+def _shape_tensor_compatible(expected_shape, actual_shape):
+ """Returns whether actual_shape is compatible with expected_shape.
+
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
+ Args:
+ expected_shape: Integer list defining the expected shape, or tensor of same.
+ actual_shape: Shape of the tensor to test.
+ Returns:
+ New tensor.
+ """
+ with ops.name_scope('shape_tensor_equal',
+ values=[expected_shape, actual_shape]) as scope:
return math_ops.reduce_all(
- math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
+ math_ops.logical_or(
+ math_ops.equal(expected_shape, -1),
+ math_ops.equal(expected_shape, actual_shape, 'equal'),
+ name='exclude_partial_shape'),
+ name=scope)
def _is_rank(expected_rank, actual_tensor):
@@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor):
def _is_shape(expected_shape, actual_tensor, actual_shape=None):
"""Returns whether actual_tensor's shape is expected_shape.
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
Args:
expected_shape: Integer list defining the expected shape, or tensor of same.
actual_tensor: Tensor to test.
@@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None):
is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
if actual_shape is None:
actual_shape = array_ops.shape(actual_tensor, name='actual')
- shape_equal = _all_equal(
- ops.convert_to_tensor(expected_shape, name='expected'),
- actual_shape)
+ shape_equal = _shape_tensor_compatible(expected_shape, actual_shape)
return math_ops.logical_and(is_rank, shape_equal, name=scope)
def _assert_shape_op(expected_shape, actual_tensor):
"""Asserts actual_tensor's shape is expected_shape.
+ Note that unknown dimension in `expected_shape` will be ignored.
+
Args:
expected_shape: List of integers defining the expected shape, or tensor of
same.
@@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor):
"""
with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
actual_shape = array_ops.shape(actual_tensor, name='actual')
+ if (isinstance(expected_shape, tensor_shape.TensorShape)
+ and not expected_shape.is_fully_defined()):
+ expected_shape = [d if d else -1 for d in expected_shape.as_list()]
is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
return control_flow_ops.Assert(
is_shape, [
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 9db2670304..2479fe5b8d 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
@@ -185,6 +185,16 @@ class WithShapeTest(test.TestCase):
shape,
unexpected_shapes)
+ def test_with_shape_2x2_with_partial_expected_shape(self):
+ with self.test_session():
+ value = [[42, 43], [44, 45]]
+ actual_shape = [2, 2]
+ tensor = constant_op.constant(value, shape=actual_shape)
+ partial_expected_shape = tensor_shape.TensorShape([None, 2])
+ # Won't raise any exception here:
+ tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor)
+ np.testing.assert_array_equal(value, tensor_with_shape.eval())
+
def test_with_shape_none(self):
with self.test_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 693724b457..370a8caf6a 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -71,7 +71,6 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
- const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -82,17 +81,28 @@ class ImageProjectiveTransform : public OpKernel {
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
- OP_REQUIRES(ctx, shape_t.dims() == 1,
- errors::InvalidArgument("output shape must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(ctx, shape_t.NumElements() == 2,
- errors::InvalidArgument("output shape must have two elements",
- shape_t.shape().DebugString()));
- auto shape_vec = shape_t.vec<int32>();
- int32 out_height = shape_vec(0);
- int32 out_width = shape_vec(1);
- OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
- errors::InvalidArgument("output dimensions must be positive"));
+
+ int32 out_height, out_width;
+ // Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
+ if (ctx->num_inputs() >= 3) {
+ const Tensor& shape_t = ctx->input(2);
+ OP_REQUIRES(ctx, shape_t.dims() == 1,
+ errors::InvalidArgument("output shape must be 1-dimensional",
+ shape_t.shape().DebugString()));
+ OP_REQUIRES(ctx, shape_t.NumElements() == 2,
+ errors::InvalidArgument("output shape must have two elements",
+ shape_t.shape().DebugString()));
+ auto shape_vec = shape_t.vec<int32>();
+ out_height = shape_vec(0);
+ out_width = shape_vec(1);
+ OP_REQUIRES(
+ ctx, out_height > 0 && out_width > 0,
+ errors::InvalidArgument("output dimensions must be positive"));
+ } else {
+ // Shape is N (batch size), H (height), W (width), C (channels).
+ out_height = images_t.shape().dim_size(1);
+ out_width = images_t.shape().dim_size(2);
+ }
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
@@ -109,10 +119,14 @@ class ImageProjectiveTransform : public OpKernel {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("dtype"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<CPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
ImageProjectiveTransform<CPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
@@ -147,11 +161,15 @@ TF_CALL_double(DECLARE_FUNCTOR);
} // end namespace functor
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<TYPE>("dtype") \
- .HostMemory("output_shape"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<GPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype") \
+ .HostMemory("output_shape"), \
ImageProjectiveTransform<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4969ac58f9..6f7c9bb520 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -67,19 +67,7 @@ Status ResizeShapeFn(InferenceContext* c) {
c->Dim(input, 3));
}
-} // namespace
-
-// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
-// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-REGISTER_OP("ImageProjectiveTransform")
- .Input("images: dtype")
- .Input("transforms: float32")
- .Input("output_shape: int32")
- .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
- .Attr("interpolation: string")
- .Output("transformed_images: dtype")
- .SetShapeFn(ResizeShapeFn)
- .Doc(R"doc(
+static const char kImageProjectiveTransformDoc[] = R"doc(
Applies the given transform to each of the images.
Input `image` is a `Tensor` in NHWC format (where the axes are image in batch,
@@ -99,7 +87,35 @@ transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).
transformed_images: 4D `Tensor`, image(s) in NHWC format, generated by applying
the `transforms` to the `images`. Satisfies the description above.
-)doc");
+)doc";
+
+} // namespace
+
+// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
+// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
+REGISTER_OP("ImageProjectiveTransform")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ // Output shape is identical to input images.
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(kImageProjectiveTransformDoc);
+
+// V2 op supports output_shape.
+REGISTER_OP("ImageProjectiveTransformV2")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Input("output_shape: int32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ .SetShapeFn(ResizeShapeFn)
+ .Doc(kImageProjectiveTransformDoc);
REGISTER_OP("BipartiteMatch")
.Input("distance_mat: float")
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 70339d7612..376c0751ee 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.image.ops import gen_image_ops
from tensorflow.contrib.image.python.ops import image_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -262,6 +263,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
+ def test_projective_transform_v1(self):
+ """The original ImageProjectiveTransform op should take 2 arguments."""
+ image = constant_op.constant([[[[1], [0]], [[0], [1]]]])
+ transform = constant_op.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])
+ result = gen_image_ops.image_projective_transform(
+ image, transform, interpolation="NEAREST")
+ with self.cached_session():
+ self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index e7a09041ad..d4fb99a017 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -39,6 +39,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn)
# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
@@ -290,7 +291,7 @@ def transform(images,
else:
raise TypeError("Transforms should have rank 1 or 2.")
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images,
output_shape=output_shape,
transforms=transforms,
@@ -391,7 +392,7 @@ def matrices_to_flat_transforms(transform_matrices):
return transforms[:, :8]
-@ops.RegisterGradient("ImageProjectiveTransform")
+@ops.RegisterGradient("ImageProjectiveTransformV2")
def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
@@ -415,7 +416,7 @@ def _image_projective_transform_grad(op, grad):
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images=grad,
transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3],
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 28d19a0445..53c8ae5d08 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1100,9 +1100,9 @@ class _EmbeddingColumn(
raise ValueError("Must specify both `ckpt_to_load_from` and "
"`tensor_name_in_ckpt` or none of them.")
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" after "
- "2017/02/25.")
+ logging.warn("The default stddev value of initializer was changed from "
+ "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" in core "
+ "implementation (tf.feature_column.embedding_column).")
stddev = 1 / math.sqrt(sparse_id_column.length)
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
@@ -1501,8 +1501,6 @@ class _ScatteredEmbeddingColumn(
raise ValueError("initializer must be callable if specified. "
"column_name: {}".format(column_name))
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"0.1\" to \"1/sqrt(dimension)\" after 2017/02/25.")
stddev = 0.1
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index b25f11b5a6..06da32072f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -30,6 +30,7 @@ import functools
import re
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -44,6 +45,7 @@ from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -471,7 +473,8 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
- a tuple of Tensors.
+ a tuple of Tensors. Note that `fn` should not close over any other
+ Tensors or Variables.
use_data_dep: `bool`, if `True` will use a dummy data dependency to force
the recompute to happen. If `False` will use a control dependency. By
default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -485,7 +488,22 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
A wrapped fn that is identical to fn when called, but its activations will
be discarded and recomputed on the backwards pass (i.e. on a call to
tf.gradients).
+
+ Raises:
+ ValueError: if `fn` closes over any Tensors or Variables.
"""
+ # Check for closed-over Tensors/Variables
+ if fn.__code__.co_freevars:
+ closed_over_vars = dict(zip(fn.__code__.co_freevars,
+ [c.cell_contents for c in fn.__closure__]))
+ for var_name, value in six.iteritems(closed_over_vars):
+ if isinstance(value, (framework_ops.Tensor, variables_lib.Variable)):
+ raise ValueError(
+ "fn decorated with @recompute_grad closes over Tensor %s "
+ "(local variable name: %s). The decorated fn must not close over "
+ "Tensors or Variables because gradients will NOT be computed for "
+ "them through fn. To ensure correct gradients, make the "
+ "Tensor an input to fn." % (value.name, var_name))
@_safe_wraps(fn)
def wrapped(*args):
@@ -500,6 +518,62 @@ def _is_on_tpu():
return control_flow_util.GetContainingXLAContext(ctxt) is not None
+def _recomputing_grad_fn(compute_fn,
+ original_args,
+ original_vars,
+ output_grads,
+ grad_fn_variables,
+ use_data_dep,
+ tupleize_grads,
+ arg_scope,
+ var_scope,
+ has_is_recompute_kwarg):
+ """Grad fn for recompute_grad."""
+ variables = grad_fn_variables or []
+
+ # Identity ops around the inputs ensures correct gradient graph-walking.
+ inputs = [array_ops.identity(x) for x in list(original_args)]
+
+ # Recompute outputs
+ # Use a control dependency to ensure that the recompute is not eliminated by
+ # CSE and that it happens on the backwards pass.
+ ctrl_dep_grads = [g for g in output_grads if g is not None]
+ with framework_ops.control_dependencies(ctrl_dep_grads):
+ if use_data_dep:
+ inputs = _force_data_dependency(output_grads, inputs)
+ # Re-enter scopes
+ with contrib_framework_ops.arg_scope(arg_scope):
+ with variable_scope.variable_scope(var_scope, reuse=True):
+ # Re-call the function and ensure that the touched variables are the
+ # same as in the first call.
+ with backprop.GradientTape() as tape:
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = compute_fn(*inputs, **fn_kwargs)
+ recompute_vars = set(tape.watched_variables())
+ if original_vars != recompute_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs = list(outputs)
+
+ # Compute gradients
+ grads = gradients_impl.gradients(outputs, inputs + variables,
+ output_grads)
+
+ if tupleize_grads:
+ if use_data_dep:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
+ grad_inputs = grads[:len(inputs)]
+ grad_vars = grads[len(inputs):]
+ return grad_inputs, grad_vars
+
+
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
@@ -510,12 +584,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if use_data_dep_ == _USE_DEFAULT:
use_data_dep_ = _is_on_tpu()
+ # Use custom_gradient and return a grad_fn that recomputes on the backwards
+ # pass.
@custom_gradient.custom_gradient
def fn_with_recompute(*args):
"""Wrapper for fn."""
- # Forward pass
+ # Capture the variable and arg scopes so we can re-enter them when
+ # recomputing.
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
+ # Track all variables touched in the function.
with backprop.GradientTape() as tape:
fn_kwargs = {}
if has_is_recompute_kwarg:
@@ -523,46 +601,25 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
- # Backward pass
def _grad_fn(output_grads, variables=None):
- """Recompute outputs for gradient computation."""
- variables = variables or []
+ # Validate that custom_gradient passes the right variables into grad_fn.
if original_vars:
assert variables, ("Fn created variables but the variables were not "
"passed to the gradient fn.")
if set(variables) != original_vars:
raise ValueError(_WRONG_VARS_ERR)
- inputs = [array_ops.identity(x) for x in list(args)]
- # Recompute outputs
- with framework_ops.control_dependencies(output_grads):
- if use_data_dep_:
- inputs = _force_data_dependency(output_grads, inputs)
- with contrib_framework_ops.arg_scope(arg_scope):
- with variable_scope.variable_scope(vs, reuse=True):
- with backprop.GradientTape() as tape:
- fn_kwargs = {}
- if has_is_recompute_kwarg:
- fn_kwargs["is_recomputing"] = True
- outputs = fn(*inputs, **fn_kwargs)
- recompute_vars = set(tape.watched_variables())
- if original_vars != recompute_vars:
- raise ValueError(_WRONG_VARS_ERR)
-
- if not isinstance(outputs, (list, tuple)):
- outputs = [outputs]
- outputs = list(outputs)
- grads = gradients_impl.gradients(outputs, inputs + variables,
- output_grads)
-
- if tupleize_grads:
- if use_data_dep_:
- grads = _tuple_with_data_dep(grads)
- else:
- grads = control_flow_ops.tuple(grads)
- grad_inputs = grads[:len(inputs)]
- grad_vars = grads[len(inputs):]
- return grad_inputs, grad_vars
+ return _recomputing_grad_fn(
+ compute_fn=fn,
+ original_args=args,
+ original_vars=original_vars,
+ output_grads=output_grads,
+ grad_fn_variables=variables,
+ use_data_dep=use_data_dep_,
+ tupleize_grads=tupleize_grads,
+ arg_scope=arg_scope,
+ var_scope=vs,
+ has_is_recompute_kwarg=has_is_recompute_kwarg)
# custom_gradient inspects the signature of the function to determine
# whether the user expects variables passed in the grad_fn. If the function
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d5971fb9d8..c34b5a8017 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -392,6 +392,16 @@ class RecomputeTest(test.TestCase):
with self.test_session() as sess:
sess.run(grads)
+ def testErrorOnClosedOverTensor(self):
+ x = random_ops.random_uniform((4, 8))
+ y = random_ops.random_uniform((4, 8))
+ z = x * y
+
+ with self.assertRaisesWithPredicateMatch(ValueError, "closes over"):
+ @rev_block_lib.recompute_grad
+ def fn_with_capture(a): # pylint: disable=unused-variable
+ return a * z
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
index a4f5086dde..5fe883d647 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
+++ b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
@@ -199,6 +199,46 @@ does.
However, in practice, convergence with $$x_0 = 0$$ always happens (tested for a
sample of generic values for the parameters).
+### Poisson log loss
+
+Poisson log loss is defined as $$ \l(u) = e^u - uy $$ for label $$y \geq 0.$$
+Its dual is
+
+$$ \l^\star(v) = (y+v) (\log(y+v) - 1) $$
+
+and is only defined for $$ y+v > 0 $$. We then have the constraint
+
+$$ y > \a+\d. $$
+
+The dual is
+
+$$ D(\d) = -(y-\a-\d) (\log(y-\a-\d) - 1) - \bar{y} \d - \frac{A}{2} \d^2 $$
+
+and its derivative is,
+
+$$ D'(\d) = \log(y-\a-\d) - \bar{y} - A\d $$
+
+Similar to the logistic loss, we perform a change of variable to handle the
+constraint on $$ \d $$
+
+$$ y - (\a+\d) = e^x $$
+
+After this change of variable, the goal is to find the zero of this function
+
+$$ H(x) = x - \bar{y} -A(y-\a-e^x) $$
+
+whose first derivative is
+
+$$ H'(x) = 1+Ae^x $$
+
+Since this function is always positive, $$H$$ is increasing and has a unique
+zero.
+
+We can start Newton algorithm at $$\d=0$$ which corresponds to $$ x =
+\log(y-\a)$$. As before the Newton step is given by
+
+$$x_{k+1} = x_k - \frac{H(x_k)}{H'(x_k)}. $$
+
### References
[1] C. Ma et al., Adding vs. Averaging in Distributed Primal-Dual Optimization,
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ef0e08a777..1d2db1cec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -1192,6 +1192,57 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
+class SdcaWithPoissonLossTest(SdcaModelTest):
+ """SDCA optimizer test class for poisson loss."""
+
+ def testSimple(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 2),
+ ]
+ example_weights = [100.0, 100.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(
+ symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=0,
+ loss_type='poisson_loss')
+ model = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+
+ # Before minimization, the weights default to zero. There is no loss due
+ # to regularization, only unregularized loss which is 1 for each example.
+ predictions = model.predictions(examples)
+ self.assertAllClose([1.0, 1.0], predictions.eval())
+ unregularized_loss = model.unregularized_loss(examples)
+ regularized_loss = model.regularized_loss(examples)
+ approximate_duality_gap = model.approximate_duality_gap()
+ self.assertAllClose(1.0, unregularized_loss.eval())
+ self.assertAllClose(1.0, regularized_loss.eval())
+
+ # There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender
+ # (say w3 and w4). The minimization leads to:
+ # w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.
+ # w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.
+ # This gives an unregularized loss of .3167 and .3366 with regularization.
+ train_op = model.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ model.update_weights(train_op).run()
+
+ self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)
+ self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)
+
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 0047d5753a..14f59a3f64 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
+from tensorflow.python.ops.nn import log_poisson_loss
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.summary import summary
@@ -51,6 +52,7 @@ class SdcaModel(object):
* Squared loss
* Hinge loss
* Smooth hinge loss
+ * Poisson log loss
This class defines an optimizer API to train a linear model.
@@ -112,7 +114,7 @@ class SdcaModel(object):
raise ValueError('examples, variables and options must all be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
- 'smooth_hinge_loss')
+ 'smooth_hinge_loss', 'poisson_loss')
if options['loss_type'] not in supported_losses:
raise ValueError('Unsupported loss_type: ', options['loss_type'])
@@ -315,6 +317,7 @@ class SdcaModel(object):
"""Add operations to compute predictions by the model.
If logistic_loss is being used, predicted probabilities are returned.
+ If poisson_loss is being used, predictions are exponentiated.
Otherwise, (raw) linear predictions (w*x) are returned.
Args:
@@ -335,6 +338,10 @@ class SdcaModel(object):
# Convert logits to probability for logistic loss predictions.
with name_scope('sdca/logistic_prediction'):
result = math_ops.sigmoid(result)
+ elif self._options['loss_type'] == 'poisson_loss':
+ # Exponeniate the prediction for poisson loss predictions.
+ with name_scope('sdca/poisson_prediction'):
+ result = math_ops.exp(result)
return result
def _get_partitioned_update_ops(self,
@@ -624,6 +631,11 @@ class SdcaModel(object):
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
+ if self._options['loss_type'] == 'poisson_loss':
+ return math_ops.reduce_sum(math_ops.multiply(
+ log_poisson_loss(targets=labels, log_input=predictions),
+ weights)) / math_ops.reduce_sum(weights)
+
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
# first convert 0/1 labels into -1/1 labels.
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 88c70fbb8a..b6b2357873 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -133,6 +133,7 @@ cc_library(
"//conditions:default": [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
],
}),
)
diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle
index eb7fd705e1..35e7887852 100644
--- a/tensorflow/contrib/lite/examples/android/app/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -51,10 +50,5 @@ apply from: "download-models.gradle"
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile
index 8084307ac7..f460693122 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_camera_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile
index eea7ecb759..ddb77088d9 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_simple_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index 5ff0412209..a83d2c8fec 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -36,7 +36,7 @@ brew link libtool
Then you need to run a shell script to download the dependencies you need:
```bash
-tensorflow/contrib/lite/download_dependencies.sh
+tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
This will fetch copies of libraries and data from the web and install them in
@@ -46,14 +46,14 @@ With all of the dependencies set up, you can now build the library for all five
supported architectures on iOS:
```bash
-tensorflow/contrib/lite/build_ios_universal_lib.sh
+tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh
```
Under the hood this uses a makefile in `tensorflow/contrib/lite` to build the
different versions of the library, followed by a call to `lipo` to bundle them
into a universal file containing armv7, armv7s, arm64, i386, and x86_64
architectures. The resulting library is in
-`tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/libtensorflow-lite.a`.
If you get an error such as `no such file or directory: 'x86_64'` when running
`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index b984671e89..0f9d016e6d 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -7,55 +7,58 @@ Model Name | Paper_Model_Files^
------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
-NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms
-NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms
+NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 74.2% | 91.7% | 261 ms | 389 ms
+NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.8% | 96.2% | 6697 ms | 7940 ms
ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms
ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms
-Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms
-Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms
-Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms
-Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms
-Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms
-Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms
-Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms
-Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms
-Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms
-Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms
-Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms
-Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 83.9% | 34.6 ms | 48.7 ms
-Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms
-Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms
-Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms
-Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms
-Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
-Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms
-Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms
+Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 78.2% | 94.0% | 1433 ms | 1522 ms
+Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.4% | 95.2% | 2986 ms | 3139 ms
+Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.8% | 94.1% | 2731 ms | 2926 ms
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.6% | 66.6% | 6.2 ms | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.7% | 70.6% | 8.6 ms | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.5% | 72.4% | 12.1 ms | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 50.0% | 74.4% | 16.2 ms | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.5% | 79.5% | 18.1 ms | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.3% | 82.1% | 26.8 ms | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 62.0% | 83.7% | 35.6 ms | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.5% | 85.0% | 47.6 ms | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.3% | 84.1% | 34.6 ms | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.5% | 86.1% | 51.3 ms | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.4% | 87.4% | 71.7 ms | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.6% | 88.3% | 95.7 ms | 143.4 ms
+Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.5% | 85.9% | 57.4 ms | 76.8 ms
+Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.3% | 87.8% | 86.0 ms | 117.7 ms
+Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.2% | 89.3% | 118.6 ms | 167.3 ms
+Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.3% | 90.1% | 160.1 ms | 224.3 ms
^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph.
^^ The performance numbers are generated in the benchmark on Pixel-2 using
single thread large core.
+^^ Accuracy numbers were computed using the [TFLite accuracy tool](../tools/accuracy/ilsvrc)
+after excluding blacklisted images.
+
## Image classification (Quantized Models)
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
-Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms
-Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms
-Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms
-Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms
-Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms
-Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms
-Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms
-Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms
-Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms
-Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms
-Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms
-Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms
-Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms
-Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms
-Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.8% | 64.8% | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.0% | 68.4% | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.5% | 73.1% | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 55.2% | 78.4% | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.5% | 80.7% | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.2% | 82.3% | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.5% | 83.5% | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 56.2% | 79.4% | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.7% | 83.9% | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.4% | 86.4% | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.2% | 87.0% | 45.4 ms
+Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.6% | 84.3% | 24.9 ms
+Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.9% | 37.4 ms
+Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.4% | 88.3% | 51.9 ms
+Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.2% | 89.1% | 70.2 ms
## Other models
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md
index 8ed8640582..41a1892b6f 100644
--- a/tensorflow/contrib/lite/g3doc/rpi.md
+++ b/tensorflow/contrib/lite/g3doc/rpi.md
@@ -1,28 +1,36 @@
-
# TensorFlow Lite for Raspberry Pi
## Cross compiling
-### Installing toolchian
-This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
-To cross compiling TensorFlow Lite. First you should install the toolchain and libs.
+### Installing the toolchain
+
+This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image
+[tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+To cross compile TensorFlow Lite, first install the toolchain and libs.
+
```bash
sudo apt-get update
sudo apt-get install crossbuild-essential-armhf
```
-> If you are using docker, you may not use `sudo`
+
+> If you are using Docker, you may not use `sudo`.
### Building
+
Clone this Tensorflow repository, Run this script at the root of the repository to download all the dependencies:
+
> The Tensorflow repository is in `/tensorflow` if you are using `tensorflow/tensorflow:nightly-devel` docker image, just try it.
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
+
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
@@ -31,21 +39,23 @@ This should compile a static library in:
## Native compiling
This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).
-Log in to you RPI, install the toolchain.
+Log in to you Raspberry Pi, install the toolchain.
+
```bash
sudo apt-get install build-essential
```
-First, clone this TensorFlow repository. Run this at the root of the repository:
+First, clone the TensorFlow repository. Run this at the root of the repository:
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
-`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`.
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index 92f04c651c..05301ebf88 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -10,7 +10,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -44,9 +43,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -54,8 +50,6 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
index 2a08608bbb..4f3a6cdb2f 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,9 +42,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -53,6 +49,4 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:+'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index d6d62580e2..9c891fe904 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -590,10 +590,10 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
input->type);
return kTfLiteError;
}
- reference_ops::BroadcastBinaryFunction<float, float, float>(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(alpha), GetTensorDims(alpha),
- GetTensorData<float>(output), GetTensorDims(output), ApplyPrelu<float>);
+ reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
+ GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(alpha),
+ GetTensorData<float>(alpha), GetTensorShape(output),
+ GetTensorData<float>(output), ApplyPrelu<float>);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 4f30d09030..6e05f5a9b2 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -96,11 +96,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMinMax( \
- GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
- GetTensorDims(input), GetTensorData<output_type>(output), \
- GetTensorDims(output), GetComparefunction<data_type>(is_arg_max))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorData<axis_type>(axis), GetTensorShape(output), \
+ GetTensorData<output_type>(output), \
+ GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index c8cee88edf..4efa9d596d 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -125,14 +125,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ type::BatchToSpaceND(GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.crops), \
GetTensorData<int32_t>(op_context.crops), \
- GetTensorDims(op_context.crops), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a11a59aa05..af47b33922 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -94,18 +94,23 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34; // Optional
-// Output tensors.
-constexpr int kFwOutputStateTensor = 0;
-constexpr int kFwCellStateTensor = 1;
-constexpr int kFwOutputTensor = 2;
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kFwInputActivationStateTensor = 35;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kFwInputCellStateTensor = 36;
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kBwInputActivationStateTensor = 37;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kBwInputCellStateTensor = 38;
-constexpr int kBwOutputStateTensor = 3;
-constexpr int kBwCellStateTensor = 4;
-constexpr int kBwOutputTensor = 5;
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 2, scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -307,14 +312,14 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
+// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 39);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -343,13 +348,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
n_fw_cell));
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
-
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
+ n_batch * n_fw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -357,18 +370,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
- TfLiteIntArray* fw_output_state_size = TfLiteIntArrayCreate(2);
- fw_output_state_size->data[0] = n_batch;
- fw_output_state_size->data[1] = n_fw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
- fw_output_state_size));
-
- TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
- fw_cell_size->data[0] = n_batch;
- fw_cell_size->data[1] = n_fw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
-
// Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(2);
@@ -377,10 +378,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
@@ -415,13 +412,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ // Resize the output tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -429,17 +427,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, bw_output, bw_output_size));
- TfLiteIntArray* bw_output_state_size = TfLiteIntArrayCreate(2);
- bw_output_state_size->data[0] = n_batch;
- bw_output_state_size->data[1] = n_bw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
- bw_output_state_size));
-
- TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
- bw_cell_size->data[0] = n_batch;
- bw_cell_size->data[1] = n_bw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
+ n_batch * n_bw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
// Create a scratch buffer tensor.
node->temporaries->data[1] = *(scratch_tensor_index) + 1;
@@ -447,10 +440,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
@@ -518,9 +507,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_projection_bias =
GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
// Tensors for the backward cell.
@@ -563,9 +553,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_projection_bias =
GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
// n_cell and n_output will be the same size when there is no projection.
@@ -634,7 +625,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
fw_cell_bias->data.f, fw_output_gate_bias->data.f,
fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
- n_fw_cell, n_input, n_fw_output, fw_output_state->data.f,
+ n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f,
fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
}
@@ -705,7 +696,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
bw_cell_bias->data.f, bw_output_gate_bias->data.f,
bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
- n_bw_cell, n_input, n_bw_output, bw_output_state->data.f,
+ n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f,
bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index a18e1bce34..d058fab529 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -102,10 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_output_state_ = AddOutput(TensorType_FLOAT32);
- fw_cell_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
-
if (use_cifg) {
bw_input_to_input_weights_ = AddNullInput();
} else {
@@ -161,8 +157,24 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_output_state_ = AddOutput(TensorType_FLOAT32);
- bw_cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ fw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ fw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ // Adding the 2 input state tensors.
+ bw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ bw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
+
bw_output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
@@ -259,26 +271,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(bw_projection_bias_, f);
}
- void ResetFwOutputAndCellStates() {
- const int zero_buffer_size = n_fw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(fw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(fw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetBwOutputAndCellStates() {
- const int zero_buffer_size = n_bw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(bw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(bw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -340,13 +332,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int bw_projection_weights_;
int bw_projection_bias_;
- int fw_output_;
- int fw_output_state_;
- int fw_cell_state_;
+ int fw_input_activation_state_;
+ int fw_input_cell_state_;
+ int bw_input_activation_state_;
+ int bw_input_cell_state_;
+ int fw_output_;
int bw_output_;
- int bw_output_state_;
- int bw_cell_state_;
int n_batch_;
int n_input_;
@@ -417,6 +409,12 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -474,10 +472,6 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
-0.0332076, 0.123838, 0.309777, -0.17621,
-0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -500,34 +494,151 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
// Check reversed inputs.
static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -592,6 +703,12 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -642,10 +759,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
-0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -668,34 +781,143 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
- // Check reversed inputs.
- static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+TEST(LSTMOpTest,
+ BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
+ /*use_peephole=*/true, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302};
+ static float lstm_bw_golden_output[] = {
+ -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
+ 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
+
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -759,6 +981,12 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(
@@ -1343,10 +1571,6 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
0.065133, 0.024321, 0.038473, 0.062438
}};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
for (int i = 0; i < lstm.sequence_length(); i++) {
float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
float* batch0_end = batch0_start + lstm.num_inputs();
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 4162d9bb88..d988ef8b33 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -36,18 +36,32 @@ constexpr int kInputTensor = 0;
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
-constexpr int kBwWeightsTensor = 4;
-constexpr int kBwRecurrentWeightsTensor = 5;
-constexpr int kBwBiasTensor = 6;
-// State and output tensors.
-constexpr int kFwHiddenStateTensor = 0;
-constexpr int kFwOutputTensor = 1;
-constexpr int kBwHiddenStateTensor = 2;
-constexpr int kBwOutputTensor = 3;
+constexpr int kFwHiddenStateTensor = 4;
+constexpr int kBwWeightsTensor = 5;
+constexpr int kBwRecurrentWeightsTensor = 6;
+constexpr int kBwBiasTensor = 7;
+constexpr int kBwHiddenStateTensor = 8;
+// Auxiliary inputs.
+constexpr int kAuxInputTensor = 9; // Optional.
+constexpr int kFwAuxWeightsTensor = 10; // Optional.
+constexpr int kBwAuxWeightsTensor = 11; // Optional.
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kFwHiddenStateQuantized = 1,
+ kBwHiddenStateQuantized = 2,
+ kScalingFactors = 3,
+ kAuxInputQuantized = 4,
+ kNumTemporaryTensors = 5
+};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -57,8 +71,8 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -66,11 +80,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* fw_hidden_state =
+ GetInput(context, node, kFwHiddenStateTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+ const TfLiteTensor* bw_hidden_state =
+ GetInput(context, node, kBwHiddenStateTensor);
+
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_weights != nullptr) &&
+ (bw_aux_input_weights != nullptr)) ||
+ ((aux_input == nullptr) && (fw_aux_input_weights == nullptr) &&
+ (bw_aux_input_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -88,40 +121,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_bias->dims->data[0]);
TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
bw_bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);
+
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ // Check that aux_input_weights has the same dimensions (except last) as
+ // the input_weights.
+ TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
+ TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ fw_aux_input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ bw_aux_input_weights->dims->data[1]);
+ }
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- // Resize hidden states.
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- fw_hidden_state_size_array->data[0] = batch_size;
- fw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* fw_hidden_state =
- GetOutput(context, node, kFwHiddenStateTensor);
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
- fw_hidden_state_size_array));
-
- TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- bw_hidden_state_size_array->data[0] = batch_size;
- bw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* bw_hidden_state =
- GetOutput(context, node, kBwHiddenStateTensor);
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
- bw_hidden_state_size_array));
-
- // Mark hidden states as a persistent tensor.
- fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
if (is_hybrid_op) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+ if (has_aux_input) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ // No need to create a temporary tensor for the non-existent aux_input.
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
+ }
+
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
input_quantized->type = kTfLiteUInt8;
input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -129,9 +170,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
input_quantized_size));
}
- node->temporaries->data[1] = *scratch_tensor_index + 1;
+
+ node->temporaries->data[kFwHiddenStateQuantized] =
+ *scratch_tensor_index + kFwHiddenStateQuantized;
TfLiteTensor* fw_hidden_state_quantized =
- GetTemporary(context, node, /*index=*/1);
+ GetTemporary(context, node, kFwHiddenStateQuantized);
fw_hidden_state_quantized->type = kTfLiteUInt8;
fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
@@ -142,9 +185,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, context->ResizeTensor(context, fw_hidden_state_quantized,
fw_hidden_state_quantized_size));
}
- node->temporaries->data[2] = *scratch_tensor_index + 2;
+
+ node->temporaries->data[kBwHiddenStateQuantized] =
+ *scratch_tensor_index + kBwHiddenStateQuantized;
TfLiteTensor* bw_hidden_state_quantized =
- GetTemporary(context, node, /*index=*/2);
+ GetTemporary(context, node, kBwHiddenStateQuantized);
bw_hidden_state_quantized->type = kTfLiteUInt8;
bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
@@ -155,6 +200,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, context->ResizeTensor(context, bw_hidden_state_quantized,
bw_hidden_state_quantized_size));
}
+
+ // Allocate temporary tensors to store scaling factors of quantization.
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
}
// Resize outputs.
@@ -174,19 +249,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus EvalFloat(const TfLiteTensor* input,
- const TfLiteTensor* fw_input_weights,
- const TfLiteTensor* fw_recurrent_weights,
- const TfLiteTensor* fw_bias,
- const TfLiteTensor* bw_input_weights,
- const TfLiteTensor* bw_recurrent_weights,
- const TfLiteTensor* bw_bias,
- const TfLiteSequenceRNNParams* params,
- TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
- TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
+ const TfLiteTensor* bw_aux_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
const int fw_num_units = fw_input_weights->dims->data[0];
const float* fw_bias_ptr = fw_bias->data.f;
@@ -198,6 +274,13 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
const float* bw_input_weights_ptr = bw_input_weights->data.f;
const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
+ const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
+ ? fw_aux_input_weights->data.f
+ : nullptr;
+ const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
+ ? bw_aux_input_weights->data.f
+ : nullptr;
+
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
@@ -205,12 +288,17 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
- fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
@@ -219,12 +307,17 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
- bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
@@ -236,14 +329,17 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
const TfLiteTensor* bw_input_weights,
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* input_quantized,
- TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_scaling_factors,
- TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
- TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_scaling_factors,
+ const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
+ const TfLiteTensor* aux_bw_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
const int fw_num_units = fw_input_weights->dims->data[0];
const float* fw_bias_ptr = fw_bias->data.f;
@@ -263,6 +359,22 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<const int8_t*>(bw_recurrent_weights->data.uint8);
float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
+ // Set the auxiliary pointers and scales if needed.
+ int8_t* aux_fw_input_weights_ptr = nullptr;
+ float aux_fw_input_weights_scale = 0.0f;
+ int8_t* aux_bw_input_weights_ptr = nullptr;
+ float aux_bw_input_weights_scale = 0.0f;
+ int8_t* aux_quantized_input_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_fw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_fw_input_weights->data.uint8);
+ aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
+ aux_bw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_bw_input_weights->data.uint8);
+ aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
+ aux_quantized_input_ptr = reinterpret_cast<int8_t*>(aux_input_quantized);
+ }
+
// Initialize temporary storage for quantized values.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
@@ -270,8 +382,7 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<int8_t*>(fw_hidden_state_quantized->data.uint8);
int8_t* bw_quantized_hidden_state_ptr =
reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
- float* fw_scaling_factors_ptr = fw_scaling_factors->data.f;
- float* bw_scaling_factors_ptr = bw_scaling_factors->data.f;
+ float* scaling_factors_ptr = scaling_factors->data.f;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
@@ -280,15 +391,22 @@ TfLiteStatus EvalHybrid(
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
- fw_recurrent_weights_ptr, fw_recurrent_weights_scale, fw_bias_ptr,
- input_size, fw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, fw_quantized_hidden_state_ptr,
- fw_scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch);
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
float* bw_hidden_state_ptr_batch =
@@ -296,15 +414,22 @@ TfLiteStatus EvalHybrid(
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
- bw_recurrent_weights_ptr, bw_recurrent_weights_scale, bw_bias_ptr,
- input_size, bw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, bw_quantized_hidden_state_ptr,
- bw_scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch);
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
return kTfLiteOk;
@@ -326,29 +451,49 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ // Get auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
TfLiteTensor* fw_hidden_state =
- GetOutput(context, node, kFwHiddenStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ GetVariableInput(context, node, kFwHiddenStateTensor);
TfLiteTensor* bw_hidden_state =
- GetOutput(context, node, kBwHiddenStateTensor);
+ GetVariableInput(context, node, kBwHiddenStateTensor);
+
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32:
return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias,
- bw_input_weights, bw_recurrent_weights, bw_bias, params,
- fw_hidden_state, fw_output, bw_hidden_state, bw_output);
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, fw_hidden_state, fw_output, bw_hidden_state,
+ bw_output);
case kTfLiteUInt8: {
- TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
- TfLiteTensor* fw_hidden_state_quantized = GetTemporary(context, node, 1);
- TfLiteTensor* bw_hidden_state_quantized = GetTemporary(context, node, 2);
- TfLiteTensor* fw_scaling_factors = GetTemporary(context, node, 3);
- TfLiteTensor* bw_scaling_factors = GetTemporary(context, node, 4);
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* aux_input_quantized =
+ (aux_input != nullptr)
+ ? GetTemporary(context, node, kAuxInputQuantized)
+ : nullptr;
+
return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias,
- bw_input_weights, bw_recurrent_weights, bw_bias, params,
- input_quantized, fw_hidden_state_quantized,
- fw_scaling_factors, fw_hidden_state, fw_output,
- bw_hidden_state_quantized, bw_scaling_factors,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, scaling_factors, input_quantized,
+ aux_input_quantized, fw_hidden_state_quantized,
+ fw_hidden_state, fw_output, bw_hidden_state_quantized,
bw_hidden_state, bw_output);
}
default:
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 911b108eaa..3e34ba6196 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -664,13 +664,19 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
- fw_hidden_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
+ fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
- bw_hidden_state_ = AddOutput(TensorType_FLOAT32);
+ bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
+
+ aux_input_ = AddNullInput();
+ aux_fw_weights_ = AddNullInput();
+ aux_bw_weights_ = AddNullInput();
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, /*time_major=*/false,
@@ -681,9 +687,14 @@ class BidirectionalRNNOpModel : public SingleOpModel {
{fw_units_, input_size_}, // fw_weights
{fw_units_, fw_units_}, // fw_recurrent_weights
{fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_} // bw_bias
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -719,19 +730,6 @@ class BidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenStates() {
- const int fw_zero_buffer_size = fw_units_ * batches_;
- std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]);
- memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float));
- PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(),
- fw_zero_buffer.get() + fw_zero_buffer_size);
- const int bw_zero_buffer_size = bw_units_ * batches_;
- std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]);
- memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float));
- PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(),
- bw_zero_buffer.get() + bw_zero_buffer_size);
- }
-
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -753,6 +751,9 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_bias_;
int bw_hidden_state_;
int bw_output_;
+ int aux_input_;
+ int aux_fw_weights_;
+ int aux_bw_weights_;
int batches_;
int sequence_len_;
@@ -774,7 +775,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -813,8 +813,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
// Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
// following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
for (int i = 0; i < rnn.sequence_len(); i++) {
@@ -880,8 +878,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
const int output_sequence_size = output_size * rnn.sequence_len();
const int num_examples = 64;
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 51989f541f..3ed0cdb131 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -249,6 +249,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
+ int channels_in = filter->dims->data[3];
int channels_out = filter->dims->data[0];
int width = input->dims->data[2];
int height = input->dims->data[1];
@@ -372,12 +373,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
- scaling_factors->type = kTfLiteInt32;
+ scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
// Only one scale factor per batch is typically necessary. See optimized
- // implementation for why we need to allocate for height elements here.
- scaling_factors_size->data[0] = height;
+ // implementation for why we need to allocate for the height of the inputs
+ // flattened to 2D.
+ scaling_factors_size->data[0] = NumElements(input) / channels_in;
if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
@@ -549,7 +551,10 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
scaling_factors_ptr[b] *= filter->params.scale;
}
- int8_t* im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ int8_t* im2col_ptr = nullptr;
+ if (im2col != nullptr) {
+ im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ }
int8_t* filter_ptr = reinterpret_cast<int8_t*>(filter->data.uint8);
switch (kernel_type) {
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index a4b9fb1a0b..411615aa62 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,69 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
}));
}
+TEST_P(ConvolutionOpTest, PointwiseFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {1, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ // First batch
+ 1.5, 1.5, 1.5, 1.5, // row = 1
+ 3., 3., 3., 3., // row = 2
+ // Second batch
+ 1.5, 3., 4.5, 6., // row = 1
+ 1.5, 3., 4.5, 6., // row = 2
+ }));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {2, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ }));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -769,6 +832,82 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) {
0.16)));
}
+TEST_P(ConvolutionOpTest, PointwiseHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ // Example: we get 3.03156 instead of 3.
+ //
+ // Second batch:
+ // 0.5 0.5 1 1 1.5 1.5 2 2 -> 32 32 64 64 95 95 127 127 with scale factor
+ // 127/2. We care about the two 64's.
+ //
+ // Filter:
+ // 64 127 with scale factor of 127/2.
+ //
+ // (64 * 64 + 64 * 127) * (2/127)^2 gives us the expected result.
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 1.5, 1.5, 1.5, // first batch, row = 1
+ 3., 3., 3., 3., // first batch, row = 2
+ 1.5, 3., 4.5, 6., // second batch, row = 1
+ 1.5, 3., 4.5, 6., // second batch, row = 2
+ },
+ 0.0316)));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {2, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ },
+ 0.0474)));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 697b777693..f7d5f5146d 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -41,8 +41,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+
return kTfLiteOk;
}
} // namespace floor
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
index 3c177ea330..75cf19a5a7 100644
--- a/tensorflow/contrib/lite/kernels/floor_div.cc
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -97,15 +97,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
}
}
if (requires_broadcast) {
- reference_ops::BroadcastBinaryFunction<T, T, T>(
- GetTensorData<T>(input1), GetTensorDims(input1), denominator_data,
- GetTensorDims(input2), GetTensorData<T>(output), GetTensorDims(output),
- FloorDiv<T>);
+ reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), denominator_data, GetTensorShape(output),
+ GetTensorData<T>(output), FloorDiv<T>);
} else {
reference_ops::BinaryFunction<T, T, T>(
- GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output), FloorDiv<T>);
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 88a0622286..360b472c45 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -26,6 +26,21 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
int input_size, int num_units, int batch_size,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
+ bias_ptr, input_size, /*aux_input_size=*/0, num_units,
+ batch_size, activation, hidden_state_ptr_batch,
+ output_ptr_batch);
+}
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -33,6 +48,12 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
output_ptr_batch, /*result_stride=*/1);
+ // Output += aux_input * aux_input_weights (if they are not empty).
+ if (aux_input_size > 0) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
+ batch_size, output_ptr_batch, /*result_stride=*/1);
+ }
// Output += recurrent_weights * hidden_state
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
@@ -54,6 +75,28 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
int8_t* quantized_hidden_state_ptr_batch,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr,
+ /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
+ recurrent_weights_scale, bias_ptr, input_size,
+ /*aux_input_size=*/0, num_units, batch_size, activation,
+ quantized_input_ptr_batch,
+ /*aux_quantized_input_ptr_batch=*/nullptr,
+ quantized_hidden_state_ptr_batch, scaling_factors,
+ hidden_state_ptr_batch, output_ptr_batch);
+}
+
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -80,6 +123,26 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
}
+ if (aux_input_ptr_batch &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch,
+ batch_size * aux_input_size)) {
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * aux_input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, aux_input_size,
+ aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= aux_input_weights_scale;
+ }
+
+ // Output += aux_input * aux_input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size,
+ aux_quantized_input_ptr_batch, scaling_factors, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ }
+
// Save quantization and matmul computation for all zero input.
if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
batch_size * num_units)) {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 599850db60..38436c1382 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -35,6 +35,15 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs a quantized RNN batch inference step. Same as above, but for
// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and
// quantized_input_ptr_batch pointers for temporary storage of the quantized
@@ -56,6 +65,17 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch);
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index df4d871466..b6151c40b3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -27,8 +27,33 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
+using reference_ops::ArgMax;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::SpaceToBatchND;
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float, but reserved in signature for future
+ // activations.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
@@ -296,13 +321,17 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- BroadcastMul4DSlow(
- input1_data, input1_dims, input1_offset, input2_data, input2_dims,
- input2_offset, output_offset, output_multiplier,
- // This legacy version switches the sign of the output shift.
- kReverseShift * output_shift,
- // (Break to highlight preceding line.)
- output_activation_min, output_activation_max, output_data, output_dims);
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -621,6 +650,294 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.output_offset = output_offset;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// For compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ float float_activation_min;
+ float float_activation_max;
+ GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
+ SetActivationParams(float_activation_min, float_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4>.
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4> version.
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4> version.
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4>
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4>
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// Legacy Dims<4>.
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature, function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 420bc68b43..70b6994a2b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -236,6 +236,35 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
+ // Multiply.
+ float32x4_t result_f32x4 = vmulq_f32(batch_vector_f32x4, vector_f32x4);
+ // Store.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector[v] * batch_vector[v];
+ }
+ // Update the pointers.
+ result += v_size;
+ batch_vector += v_size;
+ }
+}
+
void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 63c89d1eee..e671624fe7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -52,6 +52,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
+ n_batch, result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index e4bb4e0534..70adffda3b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1948,7 +1948,7 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
const int filter_width = ArraySize(filter_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
- const int8* gemm_input_data = nullptr;
+ const int8_t* gemm_input_data = nullptr;
int num_input;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
@@ -2338,18 +2338,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthToSpaceParams op_params;
- op_params.block_size = block_size;
-
- DepthToSpace(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -2391,18 +2379,6 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- tflite::SpaceToDepthParams op_params;
- op_params.block_size = block_size;
-
- SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
@@ -2438,18 +2414,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
-// Legacy.
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
- tflite::L2NormalizationParams op_params;
- // No params need to be set for float.
-
- L2Normalization(op_params, input_shape, input_data, output_shape,
- output_data);
-}
-
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
int32* output_inv_sqrt,
int* output_shift) {
@@ -2535,18 +2499,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
-// Legacy.
-inline void L2Normalization(const uint8* input_data,
- const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
- tflite::L2NormalizationParams op_params;
- op_params.input_zero_point = input_zero_point;
-
- L2Normalization(op_params, input_shape, input_data, output_shape,
- output_data);
-}
-
inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const float* input1_data,
const RuntimeShape& input2_shape, const float* input2_data,
@@ -2888,32 +2840,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int32* input1_data,
const RuntimeShape& input2_shape, const int32* input2_data,
@@ -2931,20 +2857,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void MulNoActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int32* input1_data,
@@ -2971,20 +2883,6 @@ inline void MulNoActivation(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-template <FusedActivationFunctionType Ac>
-void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- tflite::ArithmeticParams op_params;
- // No parameters needed.
-
- MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int16* input1_data,
const RuntimeShape& input2_shape, const int16* input2_data,
@@ -3006,18 +2904,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- // No parameters needed.
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int16* input1_data,
const RuntimeShape& input2_shape, const int16* input2_data,
@@ -3049,53 +2935,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- op_params.output_offset = output_offset;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
-// Legacy Dims<4>.
-template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-// Legacy Dims<4>.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- float float_activation_min;
- float float_activation_max;
- GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
- SetActivationParams(float_activation_min, float_activation_max, &op_params);
-
- BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -3324,15 +3163,28 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -3345,14 +3197,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -3360,6 +3212,21 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
// TODO(aselle): This is not actually optimized yet.
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@@ -4233,22 +4100,6 @@ inline void LocalResponseNormalization(
}
}
-// Legacy Dims<4>.
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
- tflite::LocalResponseNormalizationParams op_params;
- op_params.range = range;
- op_params.bias = bias;
- op_params.alpha = alpha;
- op_params.beta = beta;
-
- LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
const RuntimeShape& output_shape) {
@@ -5190,14 +5041,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
output_map.array() = input_map.array().template cast<DstT>();
}
-// Legacy Dims<4> version.
-template <typename SrcT, typename DstT>
-void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
- const Dims<4>& output_dims) {
- Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor");
@@ -5206,13 +5049,6 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
output_map.array() = Eigen::floor(input_map.array());
}
-// Legacy Dims<4> version.
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
float scale, float* output_ptr) {
@@ -5586,18 +5422,15 @@ inline void ResizeBilinearGenericSmallChannel(
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const float* input_data,
- const RuntimeShape& unextended_output_size_shape,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
const RuntimeShape& unextended_output_shape,
float* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_size_shape =
- RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
@@ -5606,12 +5439,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
- int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
// Specialize for 2x2 upsample.
if (!op_params.align_corners && output_height == 2 * input_height &&
@@ -5636,43 +5466,31 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
}
}
-// Legacy Dims<4>
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- tflite::ResizeBilinearParams op_params;
- op_params.align_corners = align_corners;
- ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_size_dims), output_size_data,
- DimsToShape(output_dims), output_data);
-}
-
// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
// or int16 arithmetic.
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
- const RuntimeShape& input_shape,
+ const RuntimeShape& unextended_input_shape,
const uint8* input_data,
const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_size_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
int32 input_height = input_shape.Dims(1);
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
- TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
- int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
float height_scale =
(op_params.align_corners && output_height > 1)
@@ -5690,36 +5508,6 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
output_data);
}
-// Legacy Dims<4>
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- tflite::ResizeBilinearParams op_params;
- op_params.align_corners = align_corners;
- ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_size_dims), output_size_data,
- DimsToShape(output_dims), output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
// Helper methods for BatchToSpaceND.
// `spatial_index_dim` specifies post-crop offset index in this spatial
// dimension, i.e. spatial offset introduced by flattening batch to spatial
@@ -5808,19 +5596,6 @@ inline void BatchToSpaceND(
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- BatchToSpaceND(DimsToShape(input_dims), input_data,
- DimsToShape(block_shape_dims), block_shape_data,
- DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T>
void TypedMemset(void* ptr, T value, size_t num) {
// Optimization for common cases where memset() will suffice.
@@ -5978,49 +5753,6 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -6065,22 +5797,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
T* output_data) {
@@ -6103,22 +5819,6 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& filter_dims, int stride_width,
int stride_height, int pad_width, int pad_height,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 010b40b901..8664ebc4f6 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -86,6 +86,14 @@ void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 71ae74f34c..683ccdc74d 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -27,6 +27,28 @@ namespace tflite {
namespace reference_ops {
template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
@@ -58,6 +80,15 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
output_data);
}
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::ActivationParams params;
+ params.quantized_activation_max = max_value;
+ params.quantized_activation_min = min_value;
+ ReluX(params, input_shape, input_data, output_shape, output_data);
+}
+
template <FusedActivationFunctionType Ac>
inline void Add(int left_shift, const uint8* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
@@ -311,6 +342,30 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
+// Legacy.
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int32 input1_offset, const uint8* input2_data,
const Dims<4>& input2_dims, int32 input2_offset,
@@ -624,6 +679,377 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No params in this version.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.output_offset = output_offset;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, T* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<float>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = pad_value;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = 0;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature, function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename Op>
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
+ MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, op);
+}
+
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims),
+ output_data, std::greater<T1>());
+}
+
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+ ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data, cmp);
+}
+
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data, func);
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction(const T1* input1_data,
+ const Dims<4>& input1_dims,
+ const T2* input2_data,
+ const Dims<4>& input2_dims, R* output_data,
+ const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
+ const T2* input2_data, const Dims<4>& input2_dims,
+ R* output_data, const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index aa93e857d7..e79e75a898 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -151,6 +151,16 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = vector[v] * *batch_vector++;
+ }
+ }
+}
+
void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index a375aaffa6..3829be0c5e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -69,6 +69,11 @@ void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -161,6 +166,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ PortableVectorBatchVectorCwiseProduct(vector, v_size, batch_vector, n_batch,
+ result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 3875b73e05..62f7ade7d5 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -110,6 +110,11 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
@@ -459,18 +464,6 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthToSpaceParams op_params;
- op_params.block_size = block_size;
-
- DepthToSpace(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -523,18 +516,6 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- tflite::SpaceToDepthParams op_params;
- op_params.block_size = block_size;
-
- SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
const float* weights_data,
const Dims<4>& weights_dims, const float* bias_data,
@@ -888,7 +869,6 @@ inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
inline void ReluX(const tflite::ActivationParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
-
const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -902,16 +882,6 @@ inline void ReluX(const tflite::ActivationParams& params,
}
}
-// Legacy.
-inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
- const RuntimeShape& input_shape, uint8* output_data,
- const RuntimeShape& output_shape) {
- tflite::ActivationParams params;
- params.quantized_activation_max = max_value;
- params.quantized_activation_min = min_value;
- ReluX(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
const float* input_data,
@@ -935,18 +905,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
-// Legacy .
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
- tflite::L2NormalizationParams op_params;
- // No params need to be set for float.
-
- L2Normalization(op_params, input_shape, input_data, output_shape,
- output_data);
-}
-
inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
int32* output_inv_sqrt,
int* output_shift) {
@@ -1028,18 +986,6 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
}
}
-// Legacy.
-inline void L2Normalization(const uint8* input_data,
- const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
- tflite::L2NormalizationParams op_params;
- op_params.input_zero_point = input_zero_point;
-
- L2Normalization(op_params, input_shape, input_data, output_shape,
- output_data);
-}
-
template <typename T>
inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
@@ -1380,36 +1326,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -1468,36 +1384,6 @@ void BroadcastMul4DSlow(const ArithmeticParams& params,
}
}
-// Legacy.
-template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -1626,30 +1512,6 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
}
}
-// Legacy.
-// Transitional version that will be moved shortly to legacy_reference_ops, as
-// part of RuntimeShape revisions.
-inline void BroadcastMul4DSlow(const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
- op_params.input1_offset = input1_offset;
- op_params.input2_offset = input2_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- op_params.output_shift = output_shift;
-
- BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int16* input1_data,
const RuntimeShape& input2_shape, const int16* input2_data,
@@ -1669,18 +1531,6 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- // No params in this version.
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void Mul(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const int16* input1_data,
const RuntimeShape& input2_shape, const int16* input2_data,
@@ -1710,36 +1560,32 @@ inline void Mul(const ArithmeticParams& params,
}
}
-// Legacy Dims<4>.
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
- op_params.output_offset = output_offset;
-
- Mul(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1752,14 +1598,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for
// the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1767,12 +1613,32 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <typename T>
-inline void Div(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
@@ -1780,6 +1646,21 @@ inline void Div(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -2075,52 +1956,43 @@ inline void SubWithActivation(const ArithmeticParams& params,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- int concat_size = 0;
+template <typename Scalar>
+inline void Concatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- // For now we don't have a model with a Concatenation with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
}
- Scalar* output_ptr = output_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
- }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= output_shape.Dims(i);
}
-}
-template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
memcpy(output_ptr, input_data[i] + k * copy_size,
copy_size * sizeof(Scalar));
output_ptr += copy_size;
@@ -2128,60 +2000,78 @@ void Pack(int dim, const Scalar* const* input_data,
}
}
-template <typename Scalar>
-void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
- int dimensions, int outputs_count, Scalar* const* output_datas,
- const Dims<4>& output_dims) {
- int outer_size = 1;
- for (int i = dimensions - axis; i < 4; i++) {
- outer_size *= input_dims.sizes[i];
- }
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
- const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < outputs_count; ++i) {
- Scalar* output_ptr = output_datas[i] + copy_size * k;
- int loc = k * outputs_count * copy_size + i * copy_size;
- memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
- }
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
}
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
}
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatentation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
// when optimizng this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
+
+// template <>
+inline void ConcatenationWithScaling(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
// The arguments input_zeropoint and input_scale are expected to be an array
// that have the quantization parameters for all the inputs to the concat
// operator.
TFLITE_DCHECK_GT(inputs_count, 1);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), 4);
for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
int64_t outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
+ }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < 4; ++i) {
+ base_inner_size *= output_shape.Dims(i);
}
const float inverse_output_scale = 1.f / output_scale;
uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
const uint8* input_ptr = input_data[i] + k * copy_size;
if (input_zeropoint[i] == output_zeropoint &&
input_scale[i] == output_scale) {
@@ -2202,6 +2092,72 @@ inline void Concatenation(int concat_dim, const uint8* const* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ int outer_size = 1;
+ for (int i = dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ const int copy_size = FlatSize(**input_dims) / outer_size;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
+
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ int outer_size = 1;
+ for (int i = dimensions - axis; i < 4; i++) {
+ outer_size *= input_dims.sizes[i];
+ }
+
+ const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ Scalar* output_ptr = output_datas[i] + copy_size * k;
+ int loc = k * outputs_count * copy_size + i * copy_size;
+ memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
+ }
+ }
+}
+
template <typename Scalar>
void Pack(int dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, const int32* input_zeropoint,
@@ -2910,22 +2866,6 @@ inline void LocalResponseNormalization(
}
}
-// Legacy Dims<4>.
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
- tflite::LocalResponseNormalizationParams op_params;
- op_params.range = range;
- op_params.bias = bias;
- op_params.alpha = alpha;
- op_params.beta = beta;
-
- LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
const RuntimeShape& output_shape) {
@@ -3465,14 +3405,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
}
}
-// Legacy Dims<4> version.
-template <typename SrcT, typename DstT>
-void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
- const Dims<4>& output_dims) {
- Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3483,13 +3415,6 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
}
}
-// Legacy Dims<4> version.
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T>
inline void Gather(const T* input_data, const Dims<4>& input_dims,
int input_rank, const int32* coords_data,
@@ -3573,39 +3498,6 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, T* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- tflite::ResizeBilinearParams op_params;
- op_params.align_corners = align_corners;
- ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_size_dims), output_size_data,
- DimsToShape(output_dims), output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<float>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
-
-// Legacy.
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
-
template <typename T>
inline void SpaceToBatchND(
const SpaceToBatchParams& params,
@@ -3664,41 +3556,6 @@ inline void SpaceToBatchND(
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims,
- const int32_t pad_value) {
- tflite::SpaceToBatchParams op_params;
- op_params.output_offset = pad_value;
-
- SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(block_shape_dims), block_shape_data,
- DimsToShape(paddings_dims), paddings_data,
- DimsToShape(output_dims), output_data);
-}
-
-// Legacy if no good reason to have signature with pad_value=0.
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- tflite::SpaceToBatchParams op_params;
- op_params.output_offset = 0;
-
- SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(block_shape_dims), block_shape_data,
- DimsToShape(paddings_dims), paddings_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void BatchToSpaceND(
const RuntimeShape& unextended_input1_shape, const T* input1_data,
@@ -3751,19 +3608,6 @@ inline void BatchToSpaceND(
}
}
-// Legacy Dims<4>.
-template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- BatchToSpaceND(DimsToShape(input_dims), input_data,
- DimsToShape(block_shape_dims), block_shape_data,
- DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
- output_data);
-}
-
// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
// scalar input that provides the padding value. Therefore pad_value_ptr can be
// equivalent to a simple input1_data. For Pad, it should point to a zero
@@ -3863,50 +3707,6 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- // SetFloatOrInt(pad_value, &op_params.pad_value);
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
int begin_mask, int end_mask, int shrink_axis_mask,
@@ -4001,22 +3801,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {
for (size_t idx = 0; idx < num_elements; ++idx) {
@@ -4140,91 +3924,6 @@ inline bool ReduceGeneric(const T* input_data, const int* input_dims,
temp_index, reducer, output_data);
}
-// Computes the sum of elements across dimensions given in axis.
-template <typename T>
-inline bool Sum(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis) {
- T init_value = static_cast<T>(0);
-
- auto reducer = [](const T current, const T in) -> T { return current + in; };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the max of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceMax(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = std::numeric_limits<T>::lowest();
-
- auto reducer = [](const T current, const T in) -> T {
- return (in > current) ? in : current;
- };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the min of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceMin(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = std::numeric_limits<T>::max();
-
- auto reducer = [](const T current, const T in) -> T {
- return (in < current) ? in : current;
- };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the prod of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceProd(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = static_cast<T>(1);
-
- auto reducer = [](const T current, const T in) -> T { return in * current; };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the logical_or of elements across dimensions given in axis.
-inline bool ReduceAny(const bool* input_data, const int* input_dims,
- const int input_num_dims, bool* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- bool init_value = false;
-
- auto reducer = [](const bool current, const bool in) -> bool {
- return current || in;
- };
- return ReduceGeneric<bool>(input_data, input_dims, input_num_dims,
- output_data, output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis.
@@ -4404,33 +4103,23 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
- const RuntimeShape& input2_shape,
+ const RuntimeShape& unextended_input2_shape,
const T* input2_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
T* output_data, Op op) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4448,19 +4137,9 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
}
}
-template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
- MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, op);
-}
-
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
- const T1* input_data, const RuntimeShape& output_shape,
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data, const Cmp& cmp) {
// The current ArgMax implemention can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4469,17 +4148,19 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.Dims(3), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape);
- const int depth = input_shape.Dims(3);
+ const int trailing_dim = output_shape.DimensionsCount() - 1;
+ TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(),
+ output_shape.DimensionsCount());
+ TFLITE_DCHECK_EQ(output_shape.Dims(trailing_dim), 1);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input1_shape, trailing_dim, output_shape);
+ const int depth = input1_shape.Dims(trailing_dim);
for (int i = 0; i < outer_size; ++i) {
- auto min_max_value = input_data[i * depth];
+ auto min_max_value = input1_data[i * depth];
int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
- const auto& curr_value = input_data[i * depth + d];
+ const auto& curr_value = input1_data[i * depth + d];
if (cmp(curr_value, min_max_value)) {
min_max_value = curr_value;
min_max_index = d;
@@ -4489,21 +4170,11 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
}
}
-// Legacy Dims<4> version.
-template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
- ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data, cmp);
-}
-
-// Legacy.
-// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data,
- const tflite::Dims<4>& input_dims, T2* output_data,
- const tflite::Dims<4>& output_dims) {
- ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
+void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data) {
+ ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
std::greater<T1>());
}
@@ -4640,7 +4311,6 @@ inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-// Legacy Dims<4> version.
template <typename T, ComparisonFn<T> F>
inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
@@ -4870,15 +4540,6 @@ inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-// Legacy Dims<4> version.
-template <typename T>
-inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
const T* input1_data,
@@ -4907,16 +4568,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
}
}
-// Legacy Dims<4> version.
-template <typename T>
-inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
const RuntimeShape& input2_shape, const bool* input2_data,
const RuntimeShape& output_shape, bool* output_data,
@@ -4928,24 +4579,21 @@ inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
}
}
-// Legacy Dims<4> version.
-inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
- const bool* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims,
- const std::function<bool(bool, bool)>& func) {
- Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data, func);
-}
-
inline void BroadcastLogical4DSlow(
- const RuntimeShape& input1_shape, const bool* input1_data,
- const RuntimeShape& input2_shape, const bool* input2_data,
- const RuntimeShape& output_shape, bool* output_data,
+ const RuntimeShape& unextended_input1_shape, const bool* input1_data,
+ const RuntimeShape& unextended_input2_shape, const bool* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data,
const std::function<bool(bool, bool)>& func) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4963,18 +4611,6 @@ inline void BroadcastLogical4DSlow(
}
}
-// Legacy Dims<4> version.
-inline void BroadcastLogical(const bool* input1_data,
- const Dims<4>& input1_dims,
- const bool* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims,
- const std::function<bool(bool, bool)>& func) {
- BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
-}
-
// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
// generalized and efficient BroadcastBinaryFunction.
//
@@ -4982,16 +4618,21 @@ inline void BroadcastLogical(const bool* input1_data,
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
- const T1* input1_data,
- const RuntimeShape& input2_shape,
- const T2* input2_data,
- const RuntimeShape& output_shape,
- R* output_data, R (*func)(T1, T2)) {
+inline void BroadcastBinaryFunction4DSlow(
+ const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+ const RuntimeShape& unextended_output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -5009,31 +4650,17 @@ inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
}
}
-// Legacy Dims<4> version.
-//
-// R: Result type. T1: Input 1 type. T2: Input 2 type.
-template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction(const T1* input1_data,
- const Dims<4>& input1_dims,
- const T2* input2_data,
- const Dims<4>& input2_dims, R* output_data,
- const Dims<4>& output_dims,
- R (*func)(T1, T2)) {
- BroadcastBinaryFunction4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
-}
-
-// Legacy Dims<4> version.
-//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
// TODO(renjieliu): Refactor other binary functions to use this one.
template <typename R, typename T1, typename T2>
-inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
- const T2* input2_data, const Dims<4>& input2_dims,
- R* output_data, const Dims<4>& output_dims,
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape, R* output_data,
R (*func)(T1, T2)) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = func(input1_data[i], input2_data[i]);
}
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
index 3d8765f11b..15df31f75a 100644
--- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -28,14 +28,12 @@ template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
int input_height, int output_width,
int output_height, float error_threshold) {
- Dims<4> input_dims_inference =
- MakeDimsForInference(depth, input_width, input_height, batch);
- Dims<4> output_dims_inference =
- MakeDimsForInference(depth, output_width, output_height, batch);
+ RuntimeShape input_dims_inference({batch, input_height, input_width, depth});
+ RuntimeShape output_dims_inference(
+ {batch, output_height, output_width, depth});
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int output_buffer_size =
- RequiredBufferSizeForDims(output_dims_inference);
+ const int input_buffer_size = input_dims_inference.FlatSize();
+ const int output_buffer_size = output_dims_inference.FlatSize();
std::vector<T> input_data(input_buffer_size, 0);
std::vector<T> reference_output_data(output_buffer_size, 0);
@@ -47,15 +45,19 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
const T max_amplitude = static_cast<T>(255);
FillRandom(&input_data, min_amplitude, max_amplitude);
- Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
+ RuntimeShape output_size_dims({1, 1, 1, 2});
std::vector<int32> output_size_data = {output_height, output_width};
- reference_ops::ResizeBilinear(
- input_data.data(), input_dims_inference, output_size_data.data(),
- output_size_dims, reference_output_data.data(), output_dims_inference);
- optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference,
- output_size_data.data(), output_size_dims,
- output_data.data(), output_dims_inference);
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = false;
+
+ reference_ops::ResizeBilinear(op_params, input_dims_inference,
+ input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference,
+ reference_output_data.data());
+ optimized_ops::ResizeBilinear(
+ op_params, input_dims_inference, input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference, output_data.data());
double sum_diff = 0;
float max_abs_val = 0;
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 1ff8cfe39c..748356d1bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -101,6 +101,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index e8343f1223..240fb64ca3 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -555,6 +555,120 @@ TEST(uKernels, ZeroVectorTest) {
ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
}
+TEST(uKernels, VectorBatchVectorCwiseProductAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
+ kBatchSize, output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 2.310000, 7.040000, 14.190000, 23.760000, 35.750000, 50.159996, 66.989998,
+ 86.240005, 107.909996, 112.110008, 134.542084, 159.014389, 185.526901,
+ 214.079605, 244.672485, 277.305603, 311.978912, 348.692413, 387.446136,
+ 428.240051, 471.074066, 515.948364, 562.862854, 611.817566, 662.812500,
+ 715.847595, 770.922974, 828.038452, 0.000000,
+ /* batch 1 */
+ -2.310000, -7.040000, -14.190000, -23.760000, -35.750000, -50.159996,
+ -66.989998, -86.240005, -107.909996, -112.110008, -134.542084,
+ -159.014389, -185.526901, -214.079605, -244.672485, -277.305603,
+ -311.978912, -348.692413, -387.446136, -428.240051, -471.074066,
+ -515.948364, -562.862854, -611.817566, -662.812500, -715.847595,
+ -770.922974, -828.038452, 0.000000,
+ /* batch 2 */
+ 2.310000, -7.040000, 14.190000, -23.760000, 35.750000, -50.159996,
+ 66.989998, -86.240005, 107.909996, -112.110008, 134.542084, -159.014389,
+ 185.526901, -214.079605, 244.672485, -277.305603, 311.978912, -348.692413,
+ 387.446136, -428.240051, 471.074066, -515.948364, 562.862854, -611.817566,
+ 662.812500, -715.847595, 770.922974, -828.038452, 0.000000,
+ /* batch 3 */
+ -2.310000, 7.040000, -14.190000, 23.760000, -35.750000, 50.159996,
+ -66.989998, 86.240005, -107.909996, 112.110008, -134.542084, 159.014389,
+ -185.526901, 214.079605, -244.672485, 277.305603, -311.978912, 348.692413,
+ -387.446136, 428.240051, -471.074066, 515.948364, -562.862854, 611.817566,
+ -662.812500, 715.847595, -770.922974, 828.038452, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
+ output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
+ 77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
+ 199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
+ 408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
+ 689.587585, 743.652954, 799.758423, 0.000000,
+ /* batch 1 */
+ -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
+ -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
+ -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
+ -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
+ -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
+ -799.758423, 0.000000,
+ /* batch 2 */
+ 1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
+ 59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
+ 172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
+ 368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
+ 637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
+ /* batch 3 */
+ -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
+ -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
+ -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
+ -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
+ -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
constexpr int kVectorSize = 5;
constexpr int kBatch = 2;
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 8e17eaa964..3b296f024f 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -710,6 +710,11 @@ struct ArithmeticParams {
struct ConcatenationParams {
int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
};
struct ComparisonParams {
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index c8ce3c917d..ed46cd984f 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -30,6 +30,11 @@ inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->inputs->data[index]];
}
+inline TfLiteTensor* GetVariableInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ return (tensor->is_variable) ? tensor : nullptr;
+}
inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index a7b54c6b84..5b3536de0c 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -68,10 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorShape(input), \
- GetTensorData<float>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = 0; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +83,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = input->params.zero_point; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<uint8>(input), GetTensorShape(output), \
+ GetTensorData<uint8>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 36dca299d0..799c1528bd 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -64,11 +64,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
- type::LocalResponseNormalization( \
- GetTensorData<float>(input), GetTensorDims(input), params->radius, \
- params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
- GetTensorDims(output))
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ tflite::LocalResponseNormalizationParams op_params; \
+ op_params.range = params->radius; \
+ op_params.bias = params->bias; \
+ op_params.alpha = params->alpha; \
+ op_params.beta = params->beta; \
+ type::LocalResponseNormalization( \
+ op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
}
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index 87c2fee667..c71f3b4701 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -86,14 +86,14 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (data->requires_broadcast) {
- reference_ops::BroadcastLogical(
- GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output), func);
+ reference_ops::BroadcastLogical4DSlow(
+ GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output), func);
} else {
- reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output),
+ reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output),
func);
}
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 8d676218bd..0308a3976a 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -86,13 +86,14 @@ struct MinimumOp {
template <typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
const OpContext& op_context) {
- reference_ops::TensorFlowMaximumMinimum<data_type>(
+ reference_ops::MaximumMinimumBroadcast4DSlow(
+ GetTensorShape(op_context.input1),
GetTensorData<data_type>(op_context.input1),
- GetTensorDims(op_context.input1),
+ GetTensorShape(op_context.input2),
GetTensorData<data_type>(op_context.input2),
- GetTensorDims(op_context.input2),
+ GetTensorShape(op_context.output),
GetTensorData<data_type>(op_context.output),
- GetTensorDims(op_context.output), op_type::template op<data_type>);
+ op_type::template op<data_type>);
}
template <KernelType kernel_type, typename OpType>
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 561e39cfc6..92d8bc8b67 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -102,24 +102,28 @@ template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_MUL(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(reference_ops, Mul, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, int32_t);
}
@@ -127,13 +131,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, float);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(reference_ops, Mul, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, float);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}
@@ -149,14 +153,20 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2, TfLiteTensor* output) {
if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), -input2->params.zero_point, \
- output->params.zero_point, data->output_multiplier, \
- data->output_shift, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.input1_offset = -input1->params.zero_point; \
+ op_params.input2_offset = -input2->params.zero_point; \
+ op_params.output_offset = output->params.zero_point; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+
// The quantized version of Mul doesn't support activations, so we
// always use BroadcastMul.
if (kernel_type == kReference) {
@@ -167,10 +177,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteInt16) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
@@ -179,12 +191,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- output->params.zero_point, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.output_offset = output->params.zero_point; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4be8c243c1..55bcf3b533 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -134,12 +134,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::PadV2(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
-
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE_EQ(context, before_padding.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, after_padding.size(), 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = 4; \
+ op_params.right_padding_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.left_padding[i] = before_padding[3 - i]; \
+ op_params.right_padding[i] = after_padding[3 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index 4a539c47a8..d676de5b1d 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -80,14 +80,14 @@ template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
if (requires_broadcast) {
- reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output),
- GetTensorDims(output));
+ reference_ops::BroadcastPow4DSlow(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
} else {
- reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output));
+ reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
}
}
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 4001cf357f..ca83797936 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
+#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
@@ -296,221 +297,125 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int num_axis = static_cast<int>(NumElements(op_context.axis));
+// The underlying logic for Reduce Sum/Prod/Max/Min/Any
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, T init_value,
+ T reducer(const T current, const T in)) {
+ int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
// Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
+ if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ ResizeTempAxis(context, op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
}
-
-#define TF_LITE_SUM(kernel_type, data_type) \
- kernel_type::Sum<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+ if (op_context->input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+ op_context->output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+ op_context->output->params.zero_point);
}
-#undef TF_LITE_SUM
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::ReduceGeneric<T>(
+ GetTensorData<T>(op_context->input), op_context->input->dims->data,
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->dims->data, op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), init_value, reducer));
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_PROD(kernel_type, data_type) \
- kernel_type::ReduceProd<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
+enum ReduceType {
+ kSum,
+ kProd,
+ kMax,
+ kMin,
+ kAny,
+};
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- // TODO(wangtz): uint8 reduce_prod is not yet supported.
- default:
- return kTfLiteError;
- }
+// Eval for determined input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kSum:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(0),
+ [](const T current, const T in) -> T { return in + current; });
+ break;
+ case kProd:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(1),
+ [](const T current, const T in) -> T { return in * current; });
+ break;
+ case kMax:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::lowest(),
+ [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ });
+ break;
+ case kMin:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::max(),
+ [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_PROD
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_MAX(kernel_type, data_type) \
- kernel_type::ReduceMax<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// Template specialization for bool type
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kAny:
+ return EvalLogic<bool>(context, node, op_context, false,
+ [](const bool current, const bool in) -> bool {
+ return in || current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_MAX
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_MIN(kernel_type, data_type) \
- kernel_type::ReduceMin<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// The entry point that handles input types and then calls template functions to
+// handle ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+ if (kernel_type != kReference) {
+ return kTfLiteOk;
}
-#undef TF_LITE_MIN
- return kTfLiteOk;
-}
-
-template <KernelType kernel_type>
-TfLiteStatus EvalAny(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
- if (kernel_type == kReference) {
- reference_ops::ReduceAny(
- GetTensorData<bool>(op_context.input), op_context.input->dims->data,
- op_context.input->dims->size, GetTensorData<bool>(op_context.output),
- op_context.output->dims->data, op_context.output->dims->size,
- GetTensorData<int>(op_context.axis), num_axis,
- op_context.params->keep_dims, GetTensorData<int>(temp_index),
- GetTensorData<int>(resolved_axis));
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ return EvalType<float>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt32:
+ return EvalType<int>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt64:
+ return EvalType<int64_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteUInt8:
+ return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteBool:
+ return EvalType<bool>(context, node, &op_context, reduce_type);
+ break;
+ default:
+ return kTfLiteError;
}
-
- return kTfLiteOk;
}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -521,36 +426,37 @@ TfLiteRegistration* Register_MEAN_REF() {
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalSum<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
return &r;
}
TfLiteRegistration* Register_REDUCE_PROD_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalProd<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MAX_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMax<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MIN_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMin<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
return &r;
}
TfLiteRegistration* Register_REDUCE_ANY_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free, reduce::PrepareAny,
- reduce::EvalAny<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareAny,
+ reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 86c4cd3ee8..dafa3aebab 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -88,11 +88,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
- type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<datatype>(output), GetTensorDims(output), \
- params->align_corners)
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ tflite::ResizeBilinearParams op_params; \
+ op_params.align_corners = params->align_corners; \
+ type::ResizeBilinear(op_params, GetTensorShape(input), \
+ GetTensorData<datatype>(input), GetTensorShape(size), \
+ GetTensorData<int32>(size), GetTensorShape(output), \
+ GetTensorData<datatype>(output))
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops, float);
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 6a20e802a9..55e16506df 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -159,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
sizes.push_back(1);
}
-#define TF_LITE_SLICE(data_type) \
- optimized_ops::Slice<data_type>( \
- GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+ // The original Slice op implementation only accepted 4-D sizes. That
+ // constraint is, for the present, maintained here.
+ //
+ // The dimensions in the kernel used to be in reverse-order, and TFLite
+ // arranged the begins and sizes vectors accordingly. This macro incorporates
+ // the needed reversing.
+#define TF_LITE_SLICE(data_type) \
+ { \
+ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
+ tflite::SliceParams op_params; \
+ op_params.begin_count = 4; \
+ op_params.size_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.begin[i] = begins[3 - i]; \
+ op_params.size[i] = sizes[3 - i]; \
+ } \
+ \
+ optimized_ops::Slice<data_type>( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 03079f1c3b..8332ae32cf 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -114,14 +114,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ tflite::SpaceToBatchParams op_params; \
+ op_params.output_offset = pad_value; \
+ type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.paddings), \
GetTensorData<int32_t>(op_context.paddings), \
- GetTensorDims(op_context.paddings), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9dbe9b9eda..9238e879f8 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -79,10 +79,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
- type::SpaceToDepth<scalar>( \
- GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
- GetTensorData<scalar>(output), GetTensorDims(output))
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ tflite::SpaceToDepthParams op_params; \
+ op_params.block_size = params->block_size; \
+ type::SpaceToDepth(op_params, GetTensorShape(input), \
+ GetTensorData<scalar>(input), GetTensorShape(output), \
+ GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 0acd705950..c678f14930 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -64,10 +64,14 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensor of size {n_batch, n_output}
+constexpr int kInputActivationStateTensor = 18;
+// Cell state tensor of size {n_batch, n_cell}
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
// Temporary tensors
enum TemporaryTensor {
@@ -82,7 +86,7 @@ enum TemporaryTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
+ auto* scratch_tensor_index = new int();
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -247,8 +251,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -276,12 +280,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
n_output, n_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -289,22 +302,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
-
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
@@ -340,7 +337,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
// Allocate temporary tensors to store quantized values of input,
- // output_state and cell_state tensors.
+ // activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
*scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
@@ -354,17 +351,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
*scratch_tensor_index + kOutputStateQuantized;
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
- output_state_quantized->type = kTfLiteUInt8;
- output_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(output_state_quantized->dims,
- output_state->dims)) {
- TfLiteIntArray* output_state_quantized_size =
- TfLiteIntArrayCopy(output_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output_state_quantized,
- output_state_quantized_size));
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
}
node->temporaries->data[kCellStateQuantized] =
*scratch_tensor_index + kCellStateQuantized;
@@ -449,7 +446,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -510,7 +507,7 @@ TfLiteStatus EvalFloat(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Feed the sequence into the LSTM step-by-step.
@@ -527,7 +524,7 @@ TfLiteStatus EvalFloat(
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, output_state_ptr,
+ params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, output_ptr_batch);
}
@@ -552,9 +549,9 @@ TfLiteStatus EvalHybrid(
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -655,14 +652,14 @@ TfLiteStatus EvalHybrid(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
@@ -692,8 +689,8 @@ TfLiteStatus EvalHybrid(
n_input, n_output, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, scaling_factors_ptr,
prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ quantized_input_ptr, quantized_activation_state_ptr,
+ quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
output_ptr_batch);
}
return kTfLiteOk;
@@ -744,8 +741,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_to_output_weights->type) {
@@ -758,11 +758,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params,
- scratch_buffer, output_state, cell_state, output);
+ scratch_buffer, activation_state, cell_state, output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@@ -780,8 +780,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, output_state_quantized, cell_state_quantized,
- output_state, cell_state, output);
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index de38bdef6f..cd3aac0532 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,8 +100,14 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
+ /*is_variable=*/true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -180,22 +186,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -233,9 +223,10 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
+
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -458,6 +449,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -475,10 +469,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -519,6 +509,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -536,10 +529,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -629,6 +618,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -646,10 +638,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -691,6 +679,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -708,10 +699,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
@@ -1351,6 +1338,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1374,10 +1364,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -1418,6 +1404,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1441,10 +1430,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0d6d29a171..0180c2c498 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -31,12 +31,15 @@ namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -50,14 +53,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -74,20 +79,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
output_size_array->data[0] = (time_major) ? max_time : batch_size;
@@ -276,7 +273,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ // The hidden_state is a variable input tensor that can be modified.
+ TfLiteTensor* hidden_state =
+ const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_weights->type) {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 0adab837b0..6b48e3fff7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
BuildInterpreter({{sequence_len_, batches_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units}});
} else {
BuildInterpreter({{batches_, sequence_len_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
}
}
@@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
@@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 38f3e9881b..602f3ee5d2 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -98,7 +98,10 @@ int32_t GetAndroidSdkVersion() {
return 0;
}
-static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+int32_t GetAndroidSdkVersionCached() {
+ static int32_t androidSdkVersion = GetAndroidSdkVersion();
+ return androidSdkVersion;
+}
} // namespace
@@ -660,7 +663,7 @@ TfLiteStatus AddOpsAndParams(
break;
}
- if (nnapi_version == 11 && kAndroidSdkVersion < 28) {
+ if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) {
FATAL("Op %d needs NNAPI1.1", builtin);
}
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 47f0c8e9a2..6e30251eff 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
- data = [":interpreter_test_data"],
+ data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
@@ -130,6 +130,7 @@ py_test(
],
deps = [
":convert",
+ ":interpreter",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 12cc66dc55..1c5516ae7c 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -126,7 +126,7 @@ def build_toco_convert_protos(input_tensors,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
- quantize_weights=False,
+ post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
@@ -149,9 +149,11 @@ def build_toco_convert_protos(input_tensors,
as `input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: List of tuples of integers representing the mean and
+ quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
- Only need if `inference_type` is `QUANTIZED_UINT8`. (default None)
+ Only need if `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -171,9 +173,9 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -197,10 +199,12 @@ def build_toco_convert_protos(input_tensors,
toco.inference_type = inference_type
if inference_input_type:
toco.inference_input_type = inference_input_type
+ else:
+ toco.inference_input_type = toco.inference_type
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
- toco.quantize_weights = quantize_weights
+ toco.post_training_quantize = post_training_quantize
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
@@ -212,7 +216,7 @@ def build_toco_convert_protos(input_tensors,
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
input_array = model.input_arrays.add()
- if inference_type == lite_constants.QUANTIZED_UINT8:
+ if toco.inference_input_type == lite_constants.QUANTIZED_UINT8:
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
input_array.name = tensor_name(input_tensor)
if input_shapes is None:
@@ -226,6 +230,54 @@ def build_toco_convert_protos(input_tensors,
return model, toco
+def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
+ *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ This function is used to convert GraphDefs that cannot be loaded into
+ TensorFlow to TFLite. Conversion can be customized by providing arguments
+ that are forwarded to `build_toco_convert_protos` (see documentation for
+ details).
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`),
+ input_arrays_with_shape: Tuple of strings representing input tensor names
+ and list of integers representing input shapes
+ (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
+ into TensorFlow and when `input_tensors` is None. (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `output_tensors` is None.
+ (default None)
+ *args: See `build_toco_convert_protos`,
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+ The converted data. For example if TFLite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors=[], output_tensors=[], *args, **kwargs)
+
+ for idx, (name, shape) in enumerate(input_arrays_with_shape):
+ input_array = model_flags.input_arrays.add()
+ if kwargs["inference_type"] == lite_constants.QUANTIZED_UINT8:
+ input_array.mean_value, input_array.std_value = kwargs[
+ "quantized_input_stats"][idx]
+ input_array.name = name
+ input_array.shape.dims.extend(map(int, shape))
+
+ for name in output_arrays:
+ model_flags.output_arrays.append(name)
+
+ data = toco_convert_protos(model_flags.SerializeToString(),
+ toco_flags.SerializeToString(),
+ input_data.SerializeToString())
+ return data
+
+
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
**kwargs):
""""Convert a model using TOCO.
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index bc05514cec..59f537b82a 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python import op_hint
+from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -37,9 +40,12 @@ class ConvertTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
+
# Try running on valid graph
- result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
- self.assertTrue(result)
+ tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
+ [out_tensor])
+ self.assertTrue(tflite_model)
+
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
# all the time).
# Try running on identity graph (known fail)
@@ -52,11 +58,85 @@ class ConvertTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = convert.toco_convert(
+
+ tflite_model = convert.toco_convert(
sess.graph_def, [in_tensor], [out_tensor],
inference_type=lite_constants.QUANTIZED_UINT8,
quantized_input_stats=[(0., 1.)])
- self.assertTrue(result)
+ self.assertTrue(tflite_model)
+
+ def testGraphDefBasic(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
+ inference_type=lite_constants.FLOAT)
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual("input", input_details[0]["name"])
+ self.assertEqual(np.float32, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), input_details[0]["quantization"])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("add", output_details[0]["name"])
+ self.assertEqual(np.float32, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), output_details[0]["quantization"])
+
+ def testGraphDefQuantization(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+ _ = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+ sess = session.Session()
+
+ input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
+ output_arrays = ["output"]
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def,
+ input_arrays_map,
+ output_arrays,
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.), (0., 1.)])
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual("inputA", input_details[0]["name"])
+ self.assertEqual(np.uint8, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[0]["quantization"]) # scale, zero_point
+
+ self.assertEqual("inputB", input_details[1]["name"])
+ self.assertEqual(np.uint8, input_details[1]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[1]["quantization"]) # scale, zero_point
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("output", output_details[0]["name"])
+ self.assertEqual(np.uint8, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
class ConvertTestOpHint(test_util.TensorFlowTestCase):
@@ -243,7 +323,6 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
with self.test_session() as sess:
stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
graph_def=sess.graph_def)
- print(stubbed_graphdef)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2313bfa3b6..2de97fec86 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -42,6 +42,7 @@ from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
@@ -55,6 +56,7 @@ from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -76,9 +78,11 @@ class TocoConverter(object):
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
+ mapped to tuple of floats representing the mean and standard deviation
of the training data (e.g., {"foo" : (0., 1.)}). Only need if
- `inference_type` is `QUANTIZED_UINT8`. (default {})
+ `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default {})
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -98,9 +102,9 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -133,7 +137,12 @@ class TocoConverter(object):
```
"""
- def __init__(self, graph_def, input_tensors, output_tensors):
+ def __init__(self,
+ graph_def,
+ input_tensors,
+ output_tensors,
+ input_arrays_with_shape=None,
+ output_arrays=None):
"""Constructor for TocoConverter.
Args:
@@ -142,6 +151,17 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
+ input_arrays_with_shape: Tuple of strings representing input tensor names
+ and list of integers representing input shapes
+ (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
+ into TensorFlow and when `input_tensors` and `output_tensors` are None.
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `input_tensors` and
+ `output_tensors` are None. (default None)
+
+ Raises:
+ ValueError: Invalid arguments.
"""
self._graph_def = graph_def
self._input_tensors = input_tensors
@@ -155,10 +175,19 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
- self.quantize_weights = False
+ self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ # Attributes are used by models that cannot be loaded into TensorFlow.
+ if not self._has_valid_tensors():
+ if not input_arrays_with_shape or not output_arrays:
+ raise ValueError(
+ "If input_tensors and output_tensors are None, both "
+ "input_arrays_with_shape and output_arrays must be defined.")
+ self._input_arrays_with_shape = input_arrays_with_shape
+ self._output_arrays = output_arrays
+
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
@@ -200,6 +229,7 @@ class TocoConverter(object):
Unable to parse input file.
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
+ input_shapes is not correctly defined when required
"""
with _ops.Graph().as_default():
with _session.Session() as sess:
@@ -222,20 +252,44 @@ class TocoConverter(object):
except (_text_format.ParseError, DecodeError):
raise ValueError(
"Unable to parse input file '{}'.".format(graph_def_file))
- _import_graph_def(graph_def, name="")
-
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph,
- output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ # Handles models with custom TFLite ops that cannot be resolved in
+ # TensorFlow.
+ load_model_in_session = True
+ try:
+ _import_graph_def(graph_def, name="")
+ except _NotFoundError:
+ load_model_in_session = False
+
+ if load_model_in_session:
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ return cls(sess.graph_def, input_tensors, output_tensors)
+ else:
+ if not input_shapes:
+ raise ValueError("input_shapes must be defined for this model.")
+ if set(input_arrays) != set(input_shapes.keys()):
+ raise ValueError("input_shapes must contain a value for each item "
+ "in input_array.")
+
+ input_arrays_with_shape = [
+ (name, input_shapes[name]) for name in input_arrays
+ ]
+ return cls(
+ graph_def,
+ input_tensors=None,
+ output_tensors=None,
+ input_arrays_with_shape=input_arrays_with_shape,
+ output_arrays=output_arrays)
@classmethod
def from_saved_model(cls,
@@ -330,25 +384,25 @@ class TocoConverter(object):
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
- for tensor in self._input_tensors:
- if not tensor.get_shape():
- raise ValueError("Provide an input shape for input array '{0}'.".format(
- _tensor_name(tensor)))
- shape = tensor.get_shape().as_list()
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
- elif shape[0] is None:
- self._set_batch_size(batch_size=1)
+ if self._has_valid_tensors():
+ for tensor in self._input_tensors:
+ if not tensor.get_shape():
+ raise ValueError("Provide an input shape for input array "
+ "'{0}'.".format(_tensor_name(tensor)))
+ shape = tensor.get_shape().as_list()
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
+ elif shape[0] is None:
+ self._set_batch_size(batch_size=1)
# Get quantization stats. Ensures there is one stat per name if the stats
# are specified.
if self.quantized_input_stats:
quantized_stats = []
invalid_stats = []
- for tensor in self._input_tensors:
- name = _tensor_name(tensor)
+ for name in self.get_input_arrays():
if name in self.quantized_input_stats:
quantized_stats.append(self.quantized_input_stats[name])
else:
@@ -360,24 +414,35 @@ class TocoConverter(object):
else:
quantized_stats = None
+ converter_kwargs = {
+ "inference_type": self.inference_type,
+ "inference_input_type": self.inference_input_type,
+ "input_format": constants.TENSORFLOW_GRAPHDEF,
+ "output_format": self.output_format,
+ "quantized_input_stats": quantized_stats,
+ "default_ranges_stats": self.default_ranges_stats,
+ "drop_control_dependency": self.drop_control_dependency,
+ "reorder_across_fake_quant": self.reorder_across_fake_quant,
+ "change_concat_input_ranges": self.change_concat_input_ranges,
+ "allow_custom_ops": self.allow_custom_ops,
+ "post_training_quantize": self.post_training_quantize,
+ "dump_graphviz_dir": self.dump_graphviz_dir,
+ "dump_graphviz_video": self.dump_graphviz_video
+ }
+
# Converts model.
- result = _toco_convert_impl(
- input_data=self._graph_def,
- input_tensors=self._input_tensors,
- output_tensors=self._output_tensors,
- inference_type=self.inference_type,
- inference_input_type=self.inference_input_type,
- input_format=constants.TENSORFLOW_GRAPHDEF,
- output_format=self.output_format,
- quantized_input_stats=quantized_stats,
- default_ranges_stats=self.default_ranges_stats,
- drop_control_dependency=self.drop_control_dependency,
- reorder_across_fake_quant=self.reorder_across_fake_quant,
- change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops,
- quantize_weights=self.quantize_weights,
- dump_graphviz_dir=self.dump_graphviz_dir,
- dump_graphviz_video=self.dump_graphviz_video)
+ if self._has_valid_tensors():
+ result = _toco_convert_impl(
+ input_data=self._graph_def,
+ input_tensors=self._input_tensors,
+ output_tensors=self._output_tensors,
+ **converter_kwargs)
+ else:
+ result = _toco_convert_graph_def(
+ input_data=self._graph_def,
+ input_arrays_with_shape=self._input_arrays_with_shape,
+ output_arrays=self._output_arrays,
+ **converter_kwargs)
return result
def get_input_arrays(self):
@@ -386,7 +451,18 @@ class TocoConverter(object):
Returns:
List of strings.
"""
- return [_tensor_name(tensor) for tensor in self._input_tensors]
+ if self._has_valid_tensors():
+ return [_tensor_name(tensor) for tensor in self._input_tensors]
+ else:
+ return [name for name, _ in self._input_arrays_with_shape]
+
+ def _has_valid_tensors(self):
+ """Checks if the input and output tensors have been initialized.
+
+ Returns:
+ Bool.
+ """
+ return self._input_tensors and self._output_tensors
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -394,7 +470,14 @@ class TocoConverter(object):
Args:
batch_size: Batch size for the model. Replaces the first dimension of an
input size array if undefined. (default 1)
+
+ Raises:
+ ValueError: input_tensor is not defined.
"""
+ if not self._has_valid_tensors():
+ raise ValueError("The batch size cannot be set for this model. Please "
+ "use input_shapes parameter.")
+
for tensor in self._input_tensors:
shape = tensor.get_shape().as_list()
shape[0] = batch_size
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f13684228..1c94ba605a 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph
+class FromConstructor(test_util.TensorFlowTestCase):
+
+ # Tests invalid constructors using a dummy value for the GraphDef.
+ def testInvalidConstructor(self):
+ message = ('If input_tensors and output_tensors are None, both '
+ 'input_arrays_with_shape and output_arrays must be defined.')
+
+ # `output_arrays` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(
+ None, None, [], input_arrays_with_shape=[('input', [3, 9])])
+ self.assertEqual(message, str(error.exception))
+
+ # `input_arrays_with_shape` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(None, [], None, output_arrays=['output'])
+ self.assertEqual(message, str(error.exception))
+
+ # Tests valid constructors using a dummy value for the GraphDef.
+ def testValidConstructor(self):
+ converter = lite.TocoConverter(
+ None,
+ None,
+ None,
+ input_arrays_with_shape=[('input', [3, 9])],
+ output_arrays=['output'])
+ self.assertFalse(converter._has_valid_tensors())
+ self.assertEqual(converter.get_input_arrays(), ['input'])
+
+ with self.assertRaises(ValueError) as error:
+ converter._set_batch_size(1)
+ self.assertEqual(
+ 'The batch size cannot be set for this model. Please use '
+ 'input_shapes parameter.', str(error.exception))
+
+ converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ self.assertTrue(converter._has_valid_tensors())
+
+
class FromSessionTest(test_util.TensorFlowTestCase):
def testFloat(self):
@@ -279,6 +319,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -331,7 +372,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
- def testQuantizeWeights(self):
+ def testPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
@@ -352,14 +393,14 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_weights_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TocoConverter.from_session(
sess, [in_tensor_1], [out_tensor])
- quantized_weights_converter.quantize_weights = True
- quantized_weights_tflite = quantized_weights_converter.convert()
- self.assertTrue(quantized_weights_tflite)
+ quantized_converter.post_training_quantize = True
+ quantized_tflite = quantized_converter.convert()
+ self.assertTrue(quantized_tflite)
# Ensure that the quantized weights tflite model is smaller.
- self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+ self.assertTrue(len(quantized_tflite) < len(float_tflite))
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -373,6 +414,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -407,6 +449,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(
@@ -434,6 +477,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
@@ -451,6 +495,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
write_graph(sess.graph_def, '', graph_def_file, True)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -490,6 +535,79 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
+ # TODO(nupurgarg): Test model loading in open source.
+ def _initObjectDetectionArgs(self):
+ # Initializes the arguments required for the object detection model.
+ self._graph_def_file = resource_loader.get_path_to_datafile(
+ 'testdata/tflite_graph.pbtxt')
+ self._input_arrays = ['normalized_input_image_tensor']
+ self._output_arrays = [
+ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
+ 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
+ ]
+ self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
+
+ def testTFLiteGraphDef(self):
+ # Tests the object detection model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ converter = lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays,
+ self._input_shapes)
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(4, len(output_details))
+ self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('TFLite_Detection_PostProcess:1',
+ output_details[1]['name'])
+ self.assertTrue(([1, 10] == output_details[1]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:2',
+ output_details[2]['name'])
+ self.assertTrue(([1, 10] == output_details[2]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:3',
+ output_details[3]['name'])
+ self.assertTrue(([1] == output_details[3]['shape']).all())
+
+ def testTFLiteGraphDefInvalid(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # Missing `input_shapes`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays)
+ self.assertEqual('input_shapes must be defined for this model.',
+ str(error.exception))
+
+ # `input_shapes` does not contain the names in `input_arrays`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file,
+ self._input_arrays,
+ self._output_arrays,
+ input_shapes={'invalid-value': [1, 19]})
+ self.assertEqual(
+ 'input_shapes must contain a value for each item in input_array.',
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -628,26 +746,27 @@ class FromKerasFile(test_util.TensorFlowTestCase):
keras.backend.clear_session()
def _getSequentialModel(self):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- try:
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
- return keras_file
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
def testSequentialModel(self):
"""Test a Sequential tf.keras model with default inputs."""
@@ -752,25 +871,26 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModel(self):
"""Test a Functional tf.keras model with default inputs."""
- inputs = keras.layers.Input(shape=(3,), name='input')
- x = keras.layers.Dense(2)(inputs)
- output = keras.layers.Dense(3)(x)
-
- model = keras.models.Model(inputs, output)
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy])
- x = np.random.random((1, 3))
- y = np.random.random((1, 3))
- model.train_on_batch(x, y)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
@@ -809,36 +929,39 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModelMultipleInputs(self):
"""Test a Functional tf.keras model with multiple inputs and outputs."""
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.mae],
- loss_weights=[1., 0.5])
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- model.predict([input_a_np, input_b_np], batch_size=5)
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
+ with session.Session().as_default():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.close(fd)
os.remove(keras_file)
# Check values from converted model.
@@ -871,28 +994,29 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalSequentialModel(self):
"""Test a Functional tf.keras model containing a Sequential model."""
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model = keras.models.Model(model.input, model.output)
-
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 46bdb3e553..cc08ed3fe9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -109,8 +109,14 @@ def _convert_model(flags):
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
- std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
- mean_values = _parse_array(flags.mean_values, type_fn=int)
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
+
+ # In quantized inference, mean_value has to be integer so that the real
+ # value 0.0 is exactly representable.
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
+ else:
+ mean_values = _parse_array(flags.mean_values, type_fn=float)
quant_stats = list(zip(mean_values, std_dev_values))
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
@@ -132,14 +138,18 @@ def _convert_model(flags):
if flags.reorder_across_fake_quant:
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
if flags.change_concat_input_ranges:
- converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ converter.change_concat_input_ranges = (
+ flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
- if flags.quantize_weights:
+
+ if flags.post_training_quantize:
+ converter.post_training_quantize = flags.post_training_quantize
if flags.inference_type == lite_constants.QUANTIZED_UINT8:
- raise ValueError("--quantized_weights is not supported with "
- "--inference_type=QUANTIZED_UINT8")
- converter.quantize_weights = flags.quantize_weights
+ print("--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
+ converter.inference_type = lite_constants.FLOAT
+
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
if flags.dump_graphviz_video:
@@ -292,12 +302,13 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated integers. Used for quantization. (default None)"))
+ "comma-separated floats. Used for quantized input tensors. "
+ "(default None)"))
parser.add_argument(
"--mean_values",
type=str,
help=("Mean of training data for each input tensor, comma-separated "
- "integers. Used for quantization. (default None)"))
+ "floats. Used for quantized input tensors. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,
@@ -310,12 +321,20 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ # quantize_weights is DEPRECATED.
parser.add_argument(
"--quantize_weights",
+ dest="post_training_quantize",
action="store_true",
- help=("Store float weights as quantized weights followed by dequantize "
- "operations. Inference is still done in FLOAT, but reduces model "
- "size (at the cost of accuracy and latency)."))
+ help=argparse.SUPPRESS)
+ parser.add_argument(
+ "--post_training_quantize",
+ dest="post_training_quantize",
+ action="store_true",
+ help=(
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy). (default False)"))
# Graph manipulation flags.
parser.add_argument(
@@ -333,9 +352,14 @@ def run_main(_):
"the graph. Results in a graph that differs from the quantized "
"training graph, potentially causing differing arithmetic "
"behavior. (default False)"))
+ # Usage for this flag is --change_concat_input_ranges=true or
+ # --change_concat_input_ranges=false in order to make it clear what the flag
+ # is set to. This keeps the usage consistent with other usages of the flag
+ # where the default is different. The default value here is False.
parser.add_argument(
"--change_concat_input_ranges",
- action="store_true",
+ type=str.upper,
+ choices=["TRUE", "FALSE"],
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 02d0890a7a..a75553db84 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -213,7 +213,6 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
- "graph_transformations/quantize_weights.cc",
"graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index aef35ad490..84f71dc7a7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -236,8 +236,9 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
- Arg<bool> quantize_weights = Arg<bool>(false);
+ Arg<bool> post_training_quantize = Arg<bool>(false);
// Deprecated flags
+ Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6fdf47dedc..b52a79282c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1701,9 +1701,11 @@ void ConvertReduceOperator(const Model& model, const T& src_op,
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const tensorflow::DataType params_type =
- GetTensorFlowDataType(model, src_op.inputs[0]);
- (*new_op->mutable_attr())["T"].set_type(params_type);
+ if (src_op.type != OperatorType::kAny) {
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ }
const tensorflow::DataType indices_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
(*new_op->mutable_attr())["Tidx"].set_type(indices_type);
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 4bf47aa3c4..84680b968e 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -24,8 +24,8 @@ Table of contents:
* [Multiple output arrays](#multiple-output-arrays)
* [Specifying subgraphs](#specifying-subgraphs)
* [Graph visualizations](#graph-visualizations)
- * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot)
- * [Using --dump_graphviz](#using-dump-graphviz)
+ * [Using --output_format=GRAPHVIZ_DOT](#using-output-format-graphviz-dot)
+ * [Using --dump_graphviz_dir](#using-dump-graphviz-dir)
* [Graph "video" logging](#graph-video-logging)
* [Legend for the graph visualizations](#graphviz-legend)
@@ -247,17 +247,17 @@ function tends to get fused).
## Graph visualizations
-TOCO can export a graph to the GraphViz Dot format for easy visualization via
+TOCO can export a graph to the Graphviz Dot format for easy visualization via
either the `--output_format` flag or the `--dump_graphviz_dir` flag. The
subsections below outline the use cases for each.
-### Using `--output_format=GRAPHVIZ_DOT`
+### Using `--output_format=GRAPHVIZ_DOT` <a name="using-output-format-graphviz-dot"></a>
-The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into
+The first way to get a Graphviz rendering is to pass `GRAPHVIZ_DOT` into
`--output_format`. This results in a plausible visualization of the graph. This
-reduces the requirements that exist during conversion between other input and
-output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to
-TFLITE is failing.
+reduces the requirements that exist during conversion from a TensorFlow GraphDef
+to a TensorFlow Lite FlatBuffer. This may be useful if the conversion to TFLite
+is failing.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
@@ -287,10 +287,10 @@ google-chrome /tmp/foo.dot.pdf
Example PDF files are viewable online in the next section.
-### Using `--dump_graphviz`
+### Using `--dump_graphviz_dir`
-The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir`
-flag, specifying a destination directory to dump GraphViz rendering to. Unlike
+The second way to get a Graphviz rendering is to pass the `--dump_graphviz_dir`
+flag, specifying a destination directory to dump Graphviz rendering to. Unlike
the previous approach, this one retains the original output format. This
provides a visualization of the actual graph resulting from a specific
conversion process.
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index decc8a45a4..00bc8d4ccb 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -38,7 +38,7 @@ files. The flag `--output_file` is always required. Additionally, either
of TFLite specific transformations. Therefore, the resulting
visualization may not reflect the final set of graph
transformations. To get a final visualization with all graph
- transformations use `--dump_graphviz` instead.
+ transformations use `--dump_graphviz_dir` instead.
The following flags specify optional parameters when using SavedModels.
@@ -67,21 +67,22 @@ based on index.
* `--input_shapes`. Type: colon-separated list of comma-separated lists of
integers. Each comma-separated list of integers gives the shape of one of
- the input arrays specified in [TensorFlow
- convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
+ the input arrays specified in
+ [TensorFlow convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
* Example: `--input_shapes=1,60,80,3` for a typical vision model means a
batch size of 1, an input image height of 60, an input image width of
80, and an input image depth of 3 (representing RGB channels).
* Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo"
has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
-* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers.
+* `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
These specify the (de-)quantization parameters of the input array, when it
- is quantized.
+ is quantized. This is only needed if `inference_input_type` is
+ `QUANTIZED_UINT8`.
* The meaning of `mean_values` and `std_dev_values` is as follows: each
quantized value in the quantized input array will be interpreted as a
mathematical real number (i.e. as an input activation value) according
to the following formula:
- * `real_value = (quantized_input_value - mean_value) / std_value`.
+ * `real_value = (quantized_input_value - mean_value) / std_dev_value`.
* When performing float inference (`--inference_type=FLOAT`) on a
quantized input, the quantized input would be immediately dequantized by
the inference code according to the above formula, before proceeding
@@ -91,7 +92,8 @@ based on index.
the inference code. However, the quantization parameters of all arrays,
including those of the input arrays as specified by `mean_value` and
`std_dev_value`, determine the fixed-point multipliers used in the
- quantized inference code.
+ quantized inference code. `mean_value` must be an integer when
+ performing quantized inference.
## Transformation flags
@@ -147,10 +149,10 @@ have.
true, custom ops are created for any op that is unknown. The developer will
need to provide these to the TensorFlow Lite runtime with a custom resolver.
-* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to
- store weights as quantized weights followed by dequantize operations.
- Computation is still done in float, but reduces model size (at the cost of
- accuracy and latency).
+* `--post_training_quantize`. Type: boolean. Default: False. Boolean
+ indicating whether to quantize the weights of the converted float model.
+ Model size will be reduced and there will be latency improvements (at the
+ cost of accuracy).
## Logging flags
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 3799eac0a1..51f808d4f0 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -70,6 +70,7 @@ val = img + var
out = tf.identity(val, name="out")
with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 99f4a7d8f6..fdd0632451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -142,7 +142,6 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
-DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
@@ -178,9 +177,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReshapeAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
@@ -217,12 +217,6 @@ class PropagateDefaultMinMax : public GraphTransformation {
std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
};
-class ResolveReshapeAttributes : public GraphTransformation {
- public:
- bool Run(Model* model, std::size_t op_index) override;
- const char* Name() const override { return "ResolveReshapeAttributes"; }
-};
-
class RemoveTrivialReshape : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 28effc2a67..c25be078ff 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -561,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = model->GetArray(op->inputs[1]);
- if (!reduction_array.buffer) {
+ const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.buffer) {
return;
}
- CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
- const auto& reduction_array_vals =
- reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
- auto& output_dims = *output_array.mutable_shape()->mutable_dims();
- output_dims.clear();
- for (int i = 0; i < input_shape.dimensions_count(); i++) {
- bool is_reduction_dim = false;
- for (int r : reduction_array_vals) {
- if (i == r) {
- is_reduction_dim = true;
- }
+ CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
+ }
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
}
- if (!is_reduction_dim) {
- output_dims.push_back(input_shape.dims(i));
- } else if (keep_dims) {
- output_dims.push_back(1);
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
}
}
} else {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 8d22ae2eb1..1bc366f555 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -62,7 +62,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
type == OperatorType::kArgMax || type == OperatorType::kRelu ||
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
- type == OperatorType::kShape || type == OperatorType::kExpandDims;
+ type == OperatorType::kShape || type == OperatorType::kExpandDims ||
+ type == OperatorType::kPack || type == OperatorType::kTopK_V2;
}
// The quantized op allows output arrays of type float using
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
deleted file mode 100644
index 7a8515f6d1..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <iterator>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-namespace {
-
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this minimum size configurable.
-const int kWeightsMinSize = 1024;
-
-// Gets the quantization params from the float array.
-void GetQuantizationParamsFromArray(const Array& array,
- QuantizationParams* params) {
- const std::vector<float>& float_vals =
- array.GetBuffer<ArrayDataType::kFloat>().data;
- auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
- *params = tflite::ChooseQuantizationParams<uint8>(
- *minmax.first, *minmax.second, array.narrow_range);
-}
-
-} // namespace
-
-bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
- const auto op_it = model->operators.begin() + op_index;
- Operator* op = op_it->get();
-
- // Get the weights tensor, if the current operator has one.
- int weights_index;
- if (op->type == OperatorType::kConv ||
- op->type == OperatorType::kDepthwiseConv ||
- op->type == OperatorType::kFullyConnected) {
- weights_index = 1;
- } else if (op->type == OperatorType::kLstmCell) {
- weights_index = LstmCellOperator::WEIGHTS_INPUT;
- } else {
- return false;
- }
-
- // Return early if the array isn't a constant param, this can happen in early
- // transformation passes until transpose operations following the weight array
- // are resolved.
- const string weights = op->inputs[weights_index];
- if (!IsConstantParameterArray(*model, weights)) {
- return false;
- }
-
- // Return early if the weight tensor is not type float.
- Array& weights_array = model->GetArray(weights);
- if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
- }
-
- // Return early if the tensor is too small. Small tensors don't take up too
- // much space and can result in bad quantization results.
- if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() <
- kWeightsMinSize) {
- return false;
- }
-
- // Quantize the weight tensor to type kUint8.
- QuantizationParams params;
- GetQuantizationParamsFromArray(weights_array, &params);
- QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
-
- // Insert a Dequantize operation after the quantized weights tensor.
- auto* dequantize_op = new DequantizeOperator;
- model->operators.emplace(op_it, dequantize_op);
-
- // Create a new intermediate tensor to connect the Dequantize op to the
- // original op.
- const string dequantized_output =
- AvailableArrayName(*model, weights + "_dequantized");
- Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output);
- dequantized_output_array.data_type = ArrayDataType::kFloat;
-
- // Connect up the new Dequantize op with the weights and original op.
- op->inputs[weights_index] = dequantized_output;
- dequantize_op->inputs = {weights};
- dequantize_op->outputs = {dequantized_output};
-
- return true;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 475415e481..c698a9567a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -51,6 +51,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// Test for unary ops of types that we know how to resolve.
switch (unary_op->type) {
case OperatorType::kCast:
+ case OperatorType::kExp:
case OperatorType::kLog:
case OperatorType::kNeg:
case OperatorType::kRsqrt:
@@ -218,7 +219,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
max = std::max(max, (*input_float_data)[i]);
}
output_float_data[0] = max;
- } else if (unary_op->type == OperatorType::kNeg ||
+ } else if (unary_op->type == OperatorType::kExp ||
+ unary_op->type == OperatorType::kNeg ||
unary_op->type == OperatorType::kLog ||
unary_op->type == OperatorType::kRsqrt ||
unary_op->type == OperatorType::kSqrt ||
@@ -231,7 +233,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < output_buffer_size; i++) {
const float val = (*input_float_data)[i];
float outval = 0.f;
- if (unary_op->type == OperatorType::kNeg) {
+ if (unary_op->type == OperatorType::kExp) {
+ outval = std::exp(val);
+ } else if (unary_op->type == OperatorType::kNeg) {
outval = -val;
} else if (unary_op->type == OperatorType::kLog) {
outval = std::log(val);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 7d456af2fb..73198ac7c0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -52,6 +52,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ case OperatorType::kAny:
+ return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
default:
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index e163fc9ae1..acf1e3ede5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -20,19 +20,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "quantize_weights_test",
- srcs = ["quantize_weights_test.cc"],
- tags = ["no_oss"],
- deps = [
- "//tensorflow/contrib/lite/toco:graph_transformations",
- "//tensorflow/contrib/lite/toco:model",
- "//tensorflow/contrib/lite/toco:tooling_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-tf_cc_test(
name = "resolve_constant_concatenation_test",
srcs = ["resolve_constant_concatenation_test.cc"],
tags = ["no_oss"],
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
deleted file mode 100644
index c05eb0929f..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <math.h>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-class QuantizeWeightsTest : public ::testing::Test {
- protected:
- QuantizeWeightsTest() {}
-
- // The name of the weights input array.
- const string kWeightsName = "weights";
- // The zero_point of the values in the input array.
- const int kZeroPoint = 128;
-
- // Prepare a hypothetical TOCO model of a quantizable fully connected float
- // layer.
- void PrepareModel(Model* model, int elements_per_dim) {
- std::vector<string> fc_input_names = {"inputs", kWeightsName};
-
- const int kDim = 4;
- const int buf_size = std::pow(elements_per_dim, static_cast<double>(kDim));
- auto in_buf = absl::make_unique<float[]>(buf_size);
- // Initialize the array with values from -128.0 to 127.0, since these values
- // should be exactly representable by quantization.
- for (int i = 0; i < buf_size; i++) {
- in_buf[i] = static_cast<float>(i % 256 - kZeroPoint);
- }
-
- for (const string& fc_input_name : fc_input_names) {
- Array& in_array = model->GetOrCreateArray(fc_input_name);
- in_array.data_type = ArrayDataType::kFloat;
-
- // Initialize shape for the input array.
- Shape* in_array_shape = in_array.mutable_shape();
- std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
- in_array_shape_dim->resize(kDim, elements_per_dim);
- auto& in_array_buffer =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>();
- in_array_buffer.data.resize(buf_size);
- float* buf_ptr =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
- std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr);
- }
-
- auto* fc_op = new FullyConnectedOperator;
- fc_op->inputs = fc_input_names;
- fc_op->outputs = {"fc_op_outputs"};
- Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
- out_array.data_type = ArrayDataType::kFloat;
- Shape* out_array_shape = out_array.mutable_shape();
- std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
- out_array_shape_dim->resize(kDim, elements_per_dim);
- model->operators.push_back(std::unique_ptr<Operator>(fc_op));
- }
-};
-
-TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
- // Test that weight arrays that are large enough are quantized.
- Model model;
- // 6 elements per dim gives us 1296 elements, which is sufficient to be
- // quantized.
- PrepareModel(&model, 6);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (const auto& element : float_array_map) {
- EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
- }
- const std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& quantized_array_map = model.GetArrayMap();
- EXPECT_EQ(quantized_array_map.size(), 4);
- // After the transformation, three arrays should be type float and one array
- // should be uint8.
- int num_float = 0;
- int num_uint8 = 0;
- for (const auto& element : quantized_array_map) {
- if (element.second->data_type == ArrayDataType::kFloat) {
- num_float++;
- } else if (element.second->data_type == ArrayDataType::kUint8) {
- num_uint8++;
- } else {
- FAIL() << "Unexpected array type.";
- }
- }
- EXPECT_EQ(num_float, 3);
- EXPECT_EQ(num_uint8, 1);
- // Ensure that the values were quantized correctly.
- const std::vector<uint8>& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
- }
-
- // Ensure that a Dequantize operator has been inserted before the
- // FullyConnectedLayer.
- EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize);
-}
-
-TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
- // Test that weight arrays that are too small are left untouched.
- Model model;
- // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
- // quantized.
- PrepareModel(&model, 5);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& post_array_map = model.GetArrayMap();
- EXPECT_EQ(post_array_map.size(), 3);
- for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- // Ensure that the values remain unchanged.
- std::vector<float> const& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]);
- }
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index fa1c459f0e..2e100e37f6 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -1768,6 +1768,7 @@ struct PowOperator : Operator {
//
// Inputs:
// Inputs[0]: required: A boolean input tensor.
+// Inputs[1]: required: reduction_indices.
//
// TensorFlow equivalent: tf.reduce_any.
struct TensorFlowAnyOperator : Operator {
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 709c53606b..71cdb7703e 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -91,6 +91,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/contrib/lite/tools/optimize:quantize_weights",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5ad307af14..c79469f59b 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -16,10 +16,12 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/contrib/lite/version.h"
namespace toco {
@@ -61,6 +63,13 @@ details::OperatorKey GetOperatorKey(
return details::OperatorKey(op.type, custom_code, version);
}
+void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
+ string* file_contents) {
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
} // Anonymous namespace.
namespace details {
@@ -311,14 +320,16 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops,
+void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents) {
const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+ Export(model, allow_custom_ops, quantize_weights, output_file_contents,
+ ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -365,7 +376,7 @@ void Export(
"the standard TensorFlow Lite runtime. If you have a custom "
"implementation for them you can disable this error with "
"--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.toco_convert(). Here is a list "
+ "when calling tf.contrib.lite.TocoConverter(). Here is a list "
"of operators for which you will need custom implementations: "
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
@@ -390,9 +401,24 @@ void Export(
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- const uint8_t* buffer = builder.GetBufferPointer();
- int size = builder.GetSize();
- *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+
+ if (quantize_weights) {
+ // Call the quantize_weights tool.
+ LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
+ "dump_graphviz will only output the model before this "
+ "transformation. To visualize the output graph use "
+ "lite/tools/optimize.py.";
+ flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
+ if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
+ kTfLiteOk) {
+ LOG(QFATAL) << "Quantize weights transformation failed.";
+ }
+ WriteModelToString(q_builder, output_file_contents);
+ } else {
+ WriteModelToString(builder, output_file_contents);
+ }
}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 58ea5c725c..915d5dd3d6 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -25,18 +25,19 @@ namespace tflite {
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops,
+void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents);
// This if backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions.
inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, output_file_contents);
+ Export(model, true, false, output_file_contents);
}
// Export API with custom TFLite operator mapping.
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
namespace details {
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a95937ba0f..4994ea30de 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -52,6 +52,42 @@ class ExportTest : public ::testing::Test {
input_model_.operators.emplace_back(new SubOperator);
}
+ void BuildQuantizableTestModel() {
+ input_model_.GetOrCreateArray("inputs");
+ Array& weight_array = input_model_.GetOrCreateArray("weights");
+
+ // Make the buffer large enough for QuantizeWeights transformation to take
+ // effect.
+ int buf_size = 1296;
+ auto weight_buf = absl::make_unique<float[]>(buf_size);
+ for (int i = 0; i < buf_size; i++) {
+ // Fill the array with some garbage values.
+ weight_buf[i] = static_cast<float>(i % 128);
+ }
+
+ weight_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* weight_array_shape = weight_array.mutable_shape();
+ std::vector<int>* weight_array_shape_dim =
+ weight_array_shape->mutable_dims();
+ weight_array_shape_dim->resize(4, 6);
+ auto& weight_array_buffer =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ weight_array_buffer.data.resize(buf_size);
+ float* buf_ptr =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
+
+ {
+ auto* op = new ConvOperator;
+ op->padding.type = PaddingType::kSame;
+ op->inputs = {"inputs", "weights"};
+ input_model_.operators.emplace_back(op);
+ }
+ input_model_.operators.emplace_back(new AddOperator);
+ }
+
Model input_model_;
};
@@ -81,7 +117,7 @@ TEST_F(ExportTest, Export) {
BuildTestModel();
string result;
- Export(input_model_, true, &result);
+ Export(input_model_, true, false, &result);
auto* model = ::tflite::GetModel(result.data());
@@ -108,6 +144,20 @@ TEST_F(ExportTest, Export) {
EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
}
+TEST_F(ExportTest, QuantizeWeights) {
+ // Sanity check for quantize_weights parameter.
+ BuildQuantizableTestModel();
+ string unquantized_result;
+ Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);
+
+ BuildQuantizableTestModel();
+ string quantized_result;
+ Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);
+
+ // The quantized models should be smaller.
+ EXPECT_LT(quantized_result.size(), unquantized_result.size());
+}
+
// This test is based on a hypothetical scenario that dilation is supported
// only in Conv version 2. So Toco populates version=1 when dialation
// parameters are all 1, and version=2 otehrwise.
@@ -239,7 +289,7 @@ TEST_F(VersionedOpExportTest, Export) {
string result;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- Export(input_model_, true, &result, ops_by_type);
+ Export(input_model_, true, false, &result, ops_by_type);
auto* model = ::tflite::GetModel(result.data());
auto operator_codes = model->operator_codes();
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index c6d0a03452..f83a290195 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -160,10 +160,12 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Ignored if the output format is not TFLite."),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
- "Store weights as quantized weights followed by dequantize "
- "operations. Computation is still done in float, but reduces model "
- "size (at the cost of accuracy and latency)."),
- };
+ "Deprecated. Please use --post_training_quantize instead."),
+ Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
+ parsed_flags.post_training_quantize.default_value(),
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy).")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -257,6 +259,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
+ READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -291,9 +294,19 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
toco_flags->set_inference_input_type(input_type);
}
if (parsed_toco_flags.quantize_weights.value()) {
- QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
- << "quantize_weights is not supported with inference_type "
- "QUANTIZED_UINT8.";
+ LOG(WARNING)
+ << "--quantize_weights is deprecated. Falling back to "
+ "--post_training_quantize. Please switch --post_training_quantize.";
+ toco_flags->set_post_training_quantize(
+ parsed_toco_flags.quantize_weights.value());
+ }
+ if (parsed_toco_flags.quantize_weights.value()) {
+ if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
+ LOG(WARNING)
+ << "--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
}
#undef READ_TOCO_FLAG
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index b4a9870d58..c1dd621429 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 26.
+// Next ID to use: 27.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -173,6 +173,7 @@ message TocoFlags {
// Store weights as quantized weights followed by dequantize operations.
// Computation is still done in float, but reduces model size (at the cost of
// accuracy and latency).
+ // DEPRECATED: Please use post_training_quantize instead.
optional bool quantize_weights = 20 [default = false];
// Full filepath of folder to dump the graphs at various stages of processing
@@ -183,4 +184,9 @@ message TocoFlags {
// Boolean indicating whether to dump the graph after every graph
// transformation.
optional bool dump_graphviz_include_video = 25;
+
+ // Boolean indicating whether to quantize the weights of the converted float
+ // model. Model size will be reduced and there will be latency improvements
+ // (at the cost of accuracy).
+ optional bool post_training_quantize = 26 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 34130a02b0..7db7acb44d 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -281,12 +281,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
- if (toco_flags.quantize_weights()) {
- // Run the quantize weights transformation after batchnorms have been
- // folded into the weights.
- RunGraphTransformations(model, "quantize weights transformation",
- {new QuantizeWeights});
- }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -404,7 +398,9 @@ void Export(const TocoFlags& toco_flags, const Model& model,
ExportTensorFlowGraphDef(model, output_file_contents);
break;
case TFLITE:
- toco::tflite::Export(model, allow_custom_ops, output_file_contents);
+ toco::tflite::Export(model, allow_custom_ops,
+ toco_flags.post_training_quantize(),
+ output_file_contents);
break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD
index 21941f5c8b..1b60d6a60d 100644
--- a/tensorflow/contrib/lite/tools/accuracy/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/BUILD
@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
common_linkopts = tflite_linkopts() + select({
"//conditions:default": [],
@@ -44,6 +45,10 @@ tf_cc_test(
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
deps = [
":utils",
"@com_google_googletest//:gtest",
@@ -102,6 +107,10 @@ tf_cc_test(
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
@@ -152,6 +161,7 @@ tf_cc_test(
srcs = ["file_reader_stage_test.cc"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
deps = [
":file_reader_stage",
"@com_google_googletest//:gtest",
@@ -226,6 +236,7 @@ tf_cc_test(
srcs = ["eval_pipeline_test.cc"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
deps = [
":eval_pipeline",
"//tensorflow/cc:cc_ops",
@@ -277,6 +288,7 @@ tf_cc_test(
srcs = ["eval_pipeline_builder_test.cc"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
deps = [
":eval_pipeline_builder",
"//tensorflow/cc:cc_ops",
@@ -312,3 +324,5 @@ cc_library(
},
),
)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/README.md b/tensorflow/contrib/lite/tools/accuracy/README.md
index 769ef201d2..8100cd1e8c 100644
--- a/tensorflow/contrib/lite/tools/accuracy/README.md
+++ b/tensorflow/contrib/lite/tools/accuracy/README.md
@@ -28,13 +28,11 @@ Tensor input = ... read input for the model ...
Tensor ground_truth = ... read ground truth for the model ...
TF_CHECK_OK(eval_pipeline.Run(input1, ground_truth1));
```
-For further examples, check the usage in [imagenet accuracy evaluation binary]
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc)
+For further examples, check the usage in [imagenet accuracy evaluation binary](ilsvrc/imagenet_model_evaluator.cc)
## Measuring accuracy of published models.
### ILSVRC (Imagenet Large Scale Visual Recognition Contest) classification task
-For measuring accuracy for [ILSVRC 2012 image classification task]
-(http://www.image-net.org/challenges/LSVRC/2012/), the binary can be built
+For measuring accuracy for [ILSVRC 2012 image classification task](http://www.image-net.org/challenges/LSVRC/2012/), the binary can be built
using these
[instructions.](ilsvrc/)
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
index db4b688a45..a66812fe87 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
common_linkopts = tflite_linkopts() + select({
"//conditions:default": [],
@@ -52,6 +53,10 @@ tf_cc_test(
data = [":testdata/grace_hopper.jpg"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
deps = [
":inception_preprocessing",
"//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
@@ -85,6 +90,7 @@ cc_library(
],
"//conditions:default": [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
],
},
),
@@ -95,6 +101,7 @@ tf_cc_test(
srcs = ["imagenet_topk_eval_test.cc"],
linkopts = common_linkopts,
linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
deps = [
":imagenet_topk_eval",
"@com_google_googletest//:gtest",
@@ -137,6 +144,7 @@ cc_library(
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -169,3 +177,5 @@ tf_cc_binary(
},
),
)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
index 9b3b99451d..362ea3ac34 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
@@ -30,9 +30,17 @@ The binary takes the following parameters:
This is the path to the output file. The output is a CSV file that has top-10 accuracies in each row. Each line of output file is the cumulative accuracy after processing images in a sorted order. So first line is accuracy after processing the first image, second line is accuracy after procesing first two images. The last line of the file is accuracy after processing the entire validation set.
and the following optional parameters:
+
+* `blacklist_file_path`: `string` \
+ Path to blacklist file. This file contains the indices of images that are blacklisted for evaluation. 1762 images are blacklisted in ILSVRC dataset. For details please refer to readme.txt of ILSVRC2014 devkit.
+
* `num_images`: `int` (default=0) \
The number of images to process, if 0, all images in the directory are processed otherwise only num_images will be processed.
+* `num_threads`: `int` (default=4) \
+ The number of threads to use for evaluation.
+
+
## Downloading ILSVRC
In order to use this tool to run evaluation on the full 50K ImageNet dataset,
download the data set from http://image-net.org/request.
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
new file mode 100644
index 0000000000..b2f00e034e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
@@ -0,0 +1,1762 @@
+36
+50
+56
+103
+127
+195
+199
+226
+230
+235
+251
+254
+288
+397
+485
+543
+556
+601
+605
+652
+653
+663
+666
+697
+699
+705
+745
+774
+815
+816
+845
+848
+951
+977
+1006
+1008
+1018
+1056
+1066
+1079
+1102
+1128
+1133
+1188
+1193
+1194
+1266
+1271
+1372
+1382
+1405
+1426
+1430
+1441
+1477
+1502
+1518
+1606
+1621
+1642
+1658
+1716
+1722
+1734
+1750
+1807
+1880
+1882
+1936
+1951
+1970
+1977
+1983
+2086
+2112
+2146
+2152
+2217
+2304
+2321
+2404
+2526
+2554
+2563
+2647
+2675
+2732
+2733
+2827
+2839
+2854
+2865
+2872
+2880
+2886
+2893
+2915
+2973
+2993
+3019
+3020
+3044
+3047
+3049
+3117
+3167
+3197
+3201
+3282
+3311
+3315
+3344
+3345
+3378
+3425
+3477
+3497
+3514
+3525
+3531
+3587
+3637
+3650
+3657
+3686
+3720
+3732
+3798
+3802
+3823
+3847
+3971
+4007
+4059
+4072
+4087
+4099
+4124
+4126
+4156
+4195
+4197
+4241
+4275
+4321
+4333
+4352
+4356
+4368
+4377
+4428
+4440
+4497
+4509
+4513
+4526
+4528
+4565
+4570
+4596
+4633
+4677
+4696
+4743
+4759
+4778
+4835
+4976
+5032
+5058
+5061
+5066
+5140
+5145
+5177
+5197
+5219
+5226
+5228
+5240
+5289
+5292
+5385
+5433
+5445
+5448
+5465
+5488
+5549
+5553
+5609
+5638
+5666
+5683
+5711
+5729
+5760
+5793
+5819
+5837
+5855
+5858
+5961
+5966
+6048
+6197
+6199
+6201
+6206
+6215
+6220
+6264
+6278
+6280
+6305
+6388
+6411
+6466
+6490
+6509
+6523
+6529
+6625
+6754
+6818
+6886
+6890
+6893
+6902
+6912
+6942
+7067
+7141
+7144
+7214
+7217
+7278
+7312
+7320
+7329
+7342
+7345
+7369
+7408
+7428
+7463
+7556
+7557
+7582
+7613
+7621
+7624
+7647
+7671
+7679
+7734
+7736
+7747
+7750
+7777
+7851
+7854
+7883
+7889
+7902
+7985
+7999
+8070
+8087
+8096
+8100
+8128
+8180
+8195
+8367
+8377
+8465
+8497
+8508
+8528
+8538
+8581
+8657
+8692
+8742
+8784
+8839
+8861
+8912
+8970
+8982
+8987
+9103
+9155
+9180
+9248
+9284
+9300
+9357
+9382
+9414
+9450
+9463
+9493
+9522
+9543
+9563
+9630
+9643
+9653
+9693
+9747
+9787
+9847
+9851
+9892
+9913
+9929
+9965
+10026
+10027
+10055
+10154
+10189
+10243
+10297
+10337
+10346
+10347
+10377
+10403
+10483
+10518
+10540
+10559
+10567
+10568
+10580
+10606
+10615
+10618
+10645
+10685
+10707
+10710
+10807
+10837
+10856
+10873
+10989
+11046
+11054
+11132
+11163
+11218
+11243
+11255
+11265
+11292
+11306
+11307
+11310
+11343
+11349
+11407
+11411
+11422
+11427
+11431
+11439
+11496
+11644
+11662
+11690
+11692
+11725
+11743
+11767
+11812
+11867
+11871
+11897
+11975
+12001
+12046
+12076
+12119
+12158
+12216
+12252
+12261
+12264
+12293
+12296
+12306
+12357
+12358
+12371
+12415
+12422
+12472
+12497
+12499
+12538
+12540
+12544
+12569
+12645
+12647
+12652
+12699
+12727
+12750
+12832
+12849
+12873
+12889
+12902
+12996
+13029
+13065
+13073
+13075
+13079
+13268
+13338
+13372
+13529
+13530
+13537
+13623
+13626
+13637
+13644
+13646
+13681
+13778
+13782
+13805
+13846
+13853
+13881
+13914
+13961
+13975
+13979
+14011
+14135
+14143
+14144
+14161
+14170
+14207
+14212
+14215
+14260
+14311
+14368
+14373
+14400
+14509
+14523
+14566
+14594
+14628
+14629
+14633
+14649
+14652
+14705
+14709
+14732
+14734
+14802
+14834
+14865
+14883
+14933
+14965
+15003
+15100
+15159
+15178
+15272
+15289
+15308
+15319
+15327
+15353
+15357
+15363
+15408
+15429
+15438
+15469
+15485
+15495
+15501
+15524
+15530
+15551
+15598
+15613
+15614
+15631
+15646
+15647
+15661
+15679
+15684
+15758
+15775
+15826
+15838
+15840
+15931
+15940
+15969
+15976
+16003
+16037
+16045
+16116
+16200
+16233
+16247
+16339
+16340
+16345
+16361
+16400
+16408
+16430
+16468
+16474
+16500
+16521
+16565
+16569
+16584
+16613
+16645
+16662
+16671
+16719
+16724
+16760
+16764
+16805
+16849
+16893
+16896
+16954
+16979
+17023
+17026
+17034
+17038
+17049
+17054
+17061
+17073
+17074
+17133
+17163
+17176
+17177
+17217
+17237
+17246
+17298
+17312
+17324
+17337
+17365
+17415
+17442
+17449
+17576
+17578
+17581
+17588
+17589
+17591
+17593
+17605
+17661
+17688
+17689
+17695
+17697
+17703
+17736
+17746
+17758
+17788
+17798
+17828
+17841
+17884
+17898
+17924
+17956
+17960
+18001
+18013
+18025
+18052
+18097
+18106
+18158
+18211
+18223
+18240
+18261
+18266
+18297
+18325
+18329
+18335
+18340
+18351
+18433
+18462
+18466
+18524
+18569
+18581
+18631
+18696
+18748
+18766
+18787
+18793
+18950
+18961
+19001
+19008
+19011
+19154
+19177
+19217
+19255
+19286
+19320
+19333
+19360
+19403
+19407
+19419
+19464
+19499
+19510
+19519
+19555
+19564
+19605
+19610
+19689
+19699
+19705
+19707
+19725
+19732
+19741
+19774
+19799
+19838
+19877
+19903
+19940
+19945
+19952
+19973
+19987
+20024
+20086
+20111
+20114
+20174
+20193
+20201
+20245
+20299
+20329
+20439
+20485
+20534
+20562
+20575
+20578
+20601
+20604
+20605
+20648
+20658
+20665
+20677
+20693
+20697
+20699
+20791
+20794
+20808
+20876
+20890
+20906
+20914
+20990
+21065
+21128
+21144
+21151
+21156
+21175
+21199
+21204
+21207
+21225
+21236
+21241
+21342
+21351
+21429
+21533
+21550
+21622
+21676
+21727
+21764
+21785
+21822
+21830
+21845
+21853
+21867
+21909
+21910
+21923
+21924
+21937
+21948
+21955
+21962
+22008
+22017
+22026
+22037
+22072
+22075
+22135
+22138
+22160
+22167
+22190
+22287
+22375
+22440
+22457
+22460
+22471
+22481
+22484
+22488
+22515
+22553
+22679
+22703
+22714
+22730
+22735
+22752
+22768
+22809
+22813
+22817
+22846
+22902
+22910
+22944
+22986
+23026
+23053
+23065
+23088
+23117
+23124
+23126
+23132
+23142
+23165
+23172
+23223
+23264
+23280
+23322
+23335
+23439
+23453
+23455
+23474
+23501
+23518
+23580
+23589
+23608
+23614
+23641
+23649
+23660
+23698
+23728
+23766
+23809
+23859
+23874
+23902
+23946
+24040
+24105
+24132
+24137
+24151
+24153
+24157
+24171
+24271
+24281
+24296
+24303
+24308
+24328
+24332
+24338
+24402
+24440
+24453
+24466
+24504
+24531
+24543
+24547
+24556
+24562
+24610
+24649
+24660
+24693
+24706
+24745
+24834
+24948
+24963
+25056
+25057
+25083
+25093
+25120
+25150
+25161
+25197
+25219
+25220
+25253
+25257
+25290
+25327
+25332
+25344
+25387
+25390
+25422
+25453
+25481
+25489
+25587
+25599
+25600
+25622
+25681
+25686
+25702
+25708
+25740
+25776
+25870
+25918
+25973
+25978
+25986
+25987
+26033
+26038
+26041
+26087
+26113
+26155
+26162
+26184
+26235
+26299
+26301
+26318
+26364
+26383
+26430
+26511
+26528
+26561
+26618
+26653
+26688
+26697
+26778
+26940
+26951
+27023
+27029
+27037
+27046
+27051
+27118
+27244
+27252
+27258
+27272
+27283
+27303
+27381
+27392
+27403
+27422
+27437
+27440
+27476
+27493
+27494
+27501
+27506
+27550
+27559
+27571
+27581
+27596
+27604
+27612
+27665
+27687
+27701
+27711
+27732
+27759
+27766
+27772
+27797
+27813
+27854
+27864
+27865
+27879
+27894
+27907
+27958
+27963
+27969
+28003
+28027
+28032
+28051
+28058
+28079
+28093
+28120
+28132
+28194
+28227
+28324
+28328
+28331
+28360
+28373
+28419
+28431
+28436
+28451
+28467
+28471
+28527
+28541
+28588
+28640
+28649
+28662
+28670
+28678
+28722
+28768
+28780
+28835
+28863
+28879
+28885
+28928
+28948
+28954
+28963
+28969
+29020
+29065
+29077
+29105
+29117
+29143
+29166
+29172
+29299
+29302
+29342
+29357
+29378
+29410
+29411
+29414
+29415
+29447
+29473
+29488
+29499
+29505
+29533
+29537
+29601
+29637
+29650
+29667
+29671
+29681
+29686
+29708
+29721
+29749
+29755
+29771
+29853
+29886
+29894
+29919
+29928
+29990
+30008
+30064
+30067
+30107
+30150
+30160
+30164
+30186
+30195
+30219
+30243
+30282
+30314
+30324
+30389
+30418
+30497
+30550
+30592
+30615
+30624
+30640
+30650
+30695
+30720
+30741
+30750
+30751
+30767
+30830
+30856
+30885
+30901
+30907
+30953
+30985
+31005
+31027
+31034
+31045
+31057
+31071
+31109
+31119
+31227
+31230
+31250
+31303
+31320
+31371
+31401
+31440
+31447
+31464
+31478
+31487
+31494
+31525
+31553
+31554
+31558
+31572
+31588
+31639
+31641
+31683
+31698
+31704
+31708
+31717
+31722
+31781
+31786
+31788
+31791
+31803
+31850
+31853
+31862
+31886
+31901
+31944
+32020
+32048
+32052
+32073
+32094
+32116
+32147
+32180
+32212
+32218
+32256
+32270
+32305
+32411
+32414
+32430
+32465
+32484
+32534
+32584
+32589
+32608
+32612
+32613
+32615
+32641
+32674
+32697
+32708
+32757
+32763
+32796
+32824
+32861
+32877
+32944
+32945
+32946
+32984
+33004
+33012
+33029
+33050
+33090
+33096
+33097
+33124
+33139
+33161
+33170
+33173
+33179
+33191
+33293
+33367
+33370
+33371
+33373
+33399
+33415
+33436
+33440
+33443
+33488
+33551
+33563
+33564
+33629
+33643
+33664
+33685
+33696
+33714
+33722
+33728
+33764
+33809
+33868
+33883
+33913
+33942
+33956
+33994
+34081
+34089
+34091
+34098
+34178
+34207
+34269
+34287
+34348
+34392
+34445
+34447
+34455
+34529
+34579
+34591
+34643
+34659
+34692
+34729
+34758
+34836
+34857
+34862
+34883
+34930
+34942
+34957
+34963
+35003
+35089
+35180
+35187
+35209
+35220
+35239
+35247
+35253
+35263
+35380
+35393
+35394
+35408
+35452
+35485
+35486
+35557
+35578
+35639
+35663
+35688
+35746
+35832
+35862
+35890
+35903
+35917
+35929
+35946
+35984
+36060
+36084
+36090
+36124
+36135
+36151
+36197
+36249
+36269
+36303
+36364
+36377
+36398
+36402
+36418
+36421
+36435
+36499
+36511
+36521
+36544
+36556
+36601
+36627
+36640
+36660
+36673
+36676
+36787
+36790
+36797
+36821
+36840
+36901
+36921
+36934
+37006
+37041
+37051
+37112
+37160
+37167
+37213
+37231
+37242
+37274
+37313
+37332
+37391
+37416
+37522
+37594
+37621
+37664
+37699
+37731
+37915
+37968
+38030
+38070
+38117
+38128
+38135
+38172
+38184
+38224
+38277
+38295
+38311
+38428
+38464
+38529
+38549
+38599
+38623
+38673
+38681
+38713
+38722
+38726
+38762
+38867
+38872
+38944
+38947
+39015
+39023
+39028
+39043
+39068
+39080
+39097
+39118
+39171
+39197
+39236
+39254
+39271
+39277
+39280
+39336
+39338
+39340
+39341
+39358
+39364
+39497
+39503
+39537
+39541
+39559
+39560
+39562
+39596
+39600
+39613
+39623
+39656
+39670
+39781
+39810
+39832
+39861
+39875
+39892
+39918
+39919
+40008
+40016
+40082
+40091
+40095
+40164
+40213
+40234
+40274
+40279
+40324
+40332
+40341
+40349
+40365
+40438
+40446
+40482
+40501
+40510
+40516
+40541
+40544
+40545
+40574
+40617
+40659
+40668
+40742
+40754
+40758
+40764
+40765
+40795
+40858
+40901
+40985
+40986
+41080
+41112
+41121
+41136
+41196
+41199
+41219
+41233
+41246
+41278
+41376
+41401
+41409
+41434
+41470
+41492
+41502
+41517
+41571
+41572
+41608
+41648
+41699
+41773
+41779
+41801
+41837
+41843
+41849
+41855
+41873
+41881
+41901
+41924
+41926
+41935
+41962
+42008
+42062
+42069
+42072
+42094
+42097
+42104
+42112
+42117
+42137
+42147
+42170
+42185
+42224
+42237
+42250
+42254
+42257
+42276
+42282
+42298
+42321
+42351
+42372
+42378
+42420
+42446
+42453
+42466
+42470
+42502
+42514
+42518
+42527
+42662
+42721
+42727
+42743
+42794
+42840
+42843
+42871
+42872
+42897
+42950
+42956
+42967
+42969
+42975
+42995
+43005
+43008
+43046
+43052
+43091
+43103
+43124
+43198
+43225
+43228
+43385
+43394
+43402
+43405
+43408
+43423
+43503
+43529
+43557
+43647
+43656
+43704
+43706
+43714
+43745
+43748
+43759
+43812
+43927
+43950
+43997
+43998
+44016
+44018
+44025
+44060
+44066
+44099
+44128
+44149
+44150
+44169
+44184
+44198
+44254
+44272
+44293
+44310
+44352
+44389
+44399
+44400
+44442
+44451
+44470
+44474
+44522
+44569
+44590
+44713
+44738
+44787
+44823
+44829
+44845
+44895
+44918
+44975
+45024
+45121
+45148
+45154
+45179
+45208
+45210
+45215
+45218
+45220
+45235
+45265
+45282
+45283
+45285
+45286
+45303
+45351
+45359
+45396
+45407
+45414
+45472
+45519
+45522
+45564
+45621
+45641
+45660
+45678
+45695
+45696
+45710
+45780
+45800
+45823
+45828
+45862
+45947
+45964
+46001
+46050
+46084
+46113
+46132
+46146
+46198
+46221
+46234
+46236
+46256
+46272
+46298
+46325
+46337
+46347
+46374
+46386
+46388
+46437
+46491
+46560
+46561
+46589
+46600
+46656
+46660
+46664
+46673
+46690
+46700
+46808
+46809
+46828
+46918
+46963
+46979
+46984
+47005
+47088
+47097
+47100
+47143
+47147
+47261
+47320
+47369
+47450
+47503
+47533
+47538
+47576
+47601
+47608
+47618
+47621
+47624
+47659
+47681
+47698
+47708
+47745
+47817
+47826
+47879
+47883
+47917
+47937
+47957
+48000
+48023
+48076
+48099
+48130
+48133
+48281
+48298
+48321
+48349
+48351
+48353
+48358
+48371
+48426
+48455
+48522
+48526
+48544
+48573
+48606
+48609
+48646
+48667
+48699
+48701
+48740
+48773
+48777
+48785
+48847
+48886
+48940
+48986
+49029
+49054
+49100
+49121
+49137
+49157
+49191
+49222
+49291
+49315
+49347
+49374
+49376
+49381
+49407
+49427
+49481
+49497
+49624
+49785
+49791
+49835
+49875
+49877
+49981
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
index f361341f7c..2a8a2b9b59 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
@@ -52,18 +52,22 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
: writer_(std::move(writer)) {}
- void OnEvaluationStart(int total_number_of_images) override {}
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {}
void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
private:
- std::unique_ptr<CSVWriter> writer_;
+ std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
+ mutex mu_;
};
void ResultsWriter::OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ mutex_lock lock(mu_);
TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
writer_->Flush();
}
@@ -71,33 +75,40 @@ void ResultsWriter::OnSingleImageEvaluationComplete(
// Logs results to standard output with `kLogDelayUs` microseconds.
class ResultsLogger : public ImagenetModelEvaluator::Observer {
public:
- void OnEvaluationStart(int total_number_of_images) override;
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override;
void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
private:
- int total_num_images_ = 0;
- uint64 last_logged_time_us_ = 0;
+ uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
+ int total_num_images_ GUARDED_BY(mu_);
static constexpr int kLogDelayUs = 500 * 1000;
+ mutex mu_;
};
-void ResultsLogger::OnEvaluationStart(int total_number_of_images) {
- total_num_images_ = total_number_of_images;
- LOG(ERROR) << "Starting model evaluation: " << total_num_images_;
+void ResultsLogger::OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
+ int total_num_images = 0;
+ for (const auto& kv : shard_id_image_count_map) {
+ total_num_images += kv.second;
+ }
+ LOG(ERROR) << "Starting model evaluation: " << total_num_images;
+ mutex_lock lock(mu_);
+ total_num_images_ = total_num_images;
}
void ResultsLogger::OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
- int num_evaluated = stats.number_of_images;
-
- double current_percent = num_evaluated * 100.0 / total_num_images_;
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
auto now_us = Env::Default()->NowMicros();
-
+ int num_evaluated = stats.number_of_images;
+ mutex_lock lock(mu_);
if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
last_logged_time_us_ = now_us;
-
+ double current_percent = num_evaluated * 100.0 / total_num_images_;
LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_
<< " images, " << std::setprecision(2) << std::fixed
<< current_percent << "%";
@@ -108,15 +119,20 @@ int Main(int argc, char* argv[]) {
// TODO(shashishekhar): Make this binary configurable and model
// agnostic.
string output_file_path;
+ int num_threads = 4;
std::vector<Flag> flag_list = {
Flag("output_file_path", &output_file_path, "Path to output file."),
+ Flag("num_threads", &num_threads, "Number of threads."),
};
Flags::Parse(&argc, argv, flag_list);
std::unique_ptr<ImagenetModelEvaluator> evaluator;
CHECK(!output_file_path.empty()) << "Invalid output file path.";
- TF_CHECK_OK(ImagenetModelEvaluator::Create(argc, argv, &evaluator));
+ CHECK(num_threads > 0) << "Invalid number of threads.";
+
+ TF_CHECK_OK(
+ ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
std::ofstream output_stream(output_file_path, std::ios::out);
CHECK(output_stream) << "Unable to open output file path: '"
@@ -136,6 +152,7 @@ int Main(int argc, char* argv[]) {
ResultsLogger logger;
evaluator->AddObserver(&results_writer);
evaluator->AddObserver(&logger);
+ LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
TF_CHECK_OK(evaluator->EvaluateModel());
return 0;
}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index a88a4a0fce..63616fc3b4 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -29,7 +29,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -57,6 +60,21 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
return result;
}
+template <typename T>
+std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
+ CHECK_GT(n, 0);
+ std::vector<std::vector<T>> vecs(n);
+ int input_index = 0;
+ int vec_index = 0;
+ while (input_index < v.size()) {
+ vecs[vec_index].push_back(v[input_index]);
+ vec_index = (vec_index + 1) % n;
+ input_index++;
+ }
+ CHECK_EQ(vecs.size(), n);
+ return vecs;
+}
+
// File pattern for imagenet files.
const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
@@ -65,8 +83,36 @@ const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
namespace tensorflow {
namespace metrics {
+class CompositeObserver : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit CompositeObserver(const std::vector<Observer*>& observers)
+ : observers_(observers) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnEvaluationStart(shard_id_image_count_map);
+ }
+ }
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
+ }
+ }
+
+ private:
+ const std::vector<ImagenetModelEvaluator::Observer*>& observers_
+ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
/*static*/ Status ImagenetModelEvaluator::Create(
- int argc, char* argv[],
+ int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
Params params;
const std::vector<Flag> flag_list = {
@@ -82,8 +128,12 @@ namespace metrics {
Flag("num_images", &params.number_of_images,
"Number of examples to evaluate, pass 0 for all "
"examples. Default: 100"),
- tensorflow::Flag("model_file", &params.model_file_path,
- "Path to test tflite model file."),
+ Flag("blacklist_file_path", &params.blacklist_file_path,
+ "Path to blacklist file (optional)."
+ "Path to blacklist file where each line is a single integer that is "
+ "equal to number of blacklisted image."),
+ Flag("model_file", &params.model_file_path,
+ "Path to test tflite model file."),
};
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
if (!parse_result)
@@ -100,6 +150,12 @@ namespace metrics {
Env::Default()->FileExists(params.model_output_labels_path),
"Invalid model output labels path.");
+ if (!params.blacklist_file_path.empty()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.blacklist_file_path),
+ "Invalid blacklist path.");
+ }
+
if (params.number_of_images < 0) {
return errors::InvalidArgument("Invalid: num_examples");
}
@@ -109,28 +165,30 @@ namespace metrics {
utils::GetTFliteModelInfo(params.model_file_path, &model_info),
"Invalid TFLite model.");
- *model_evaluator =
- absl::make_unique<ImagenetModelEvaluator>(model_info, params);
+ *model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
+ model_info, params, num_threads);
return Status::OK();
}
-Status ImagenetModelEvaluator::EvaluateModel() {
- if (model_info_.input_shapes.size() != 1) {
- return errors::InvalidArgument("Invalid input shape");
- }
-
- const TensorShape& input_shape = model_info_.input_shapes[0];
- // Input should be of the shape {1, height, width, 3}
- if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
- return errors::InvalidArgument("Invalid input shape for the model.");
- }
-
+struct ImageLabel {
+ string image;
+ string label;
+};
+
+Status EvaluateModelForShard(const uint64_t shard_id,
+ const std::vector<ImageLabel>& image_labels,
+ const std::vector<string>& model_labels,
+ const utils::ModelInfo& model_info,
+ const ImagenetModelEvaluator::Params& params,
+ ImagenetModelEvaluator::Observer* observer,
+ ImagenetTopKAccuracy* eval) {
+ const TensorShape& input_shape = model_info.input_shapes[0];
const int image_height = input_shape.dim_size(1);
const int image_width = input_shape.dim_size(2);
- const bool is_quantized = (model_info_.input_types[0] == DT_UINT8);
+ const bool is_quantized = (model_info.input_types[0] == DT_UINT8);
RunTFLiteModelStage::Params tfl_model_params;
- tfl_model_params.model_file_path = params_.model_file_path;
+ tfl_model_params.model_file_path = params.model_file_path;
if (is_quantized) {
tfl_model_params.input_type = {DT_UINT8};
tfl_model_params.output_type = {DT_UINT8};
@@ -144,29 +202,77 @@ Status ImagenetModelEvaluator::EvaluateModel() {
InceptionPreprocessingStage inc(image_height, image_width, is_quantized);
RunTFLiteModelStage tfl_model_stage(tfl_model_params);
EvalPipelineBuilder builder;
- std::vector<string> model_labels;
- TF_RETURN_IF_ERROR(
- utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
- if (model_labels.size() != 1001) {
- return errors::InvalidArgument("Invalid number of labels: ",
- model_labels.size());
- }
- ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
std::unique_ptr<EvalPipeline> eval_pipeline;
auto build_status = builder.WithInputStage(&reader)
.WithPreprocessingStage(&inc)
.WithRunModelStage(&tfl_model_stage)
- .WithAccuracyEval(&eval)
+ .WithAccuracyEval(eval)
.WithInput("input_file", DT_STRING)
.Build(root, &eval_pipeline);
TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
"Failure while building eval pipeline.");
-
std::unique_ptr<Session> session(NewSession(SessionOptions()));
TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
+
+ for (const auto& image_label : image_labels) {
+ TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
+ CreateStringTensor(image_label.label)));
+ observer->OnSingleImageEvaluationComplete(
+ shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
+ }
+ return Status::OK();
+}
+
+Status FilterBlackListedImages(const string& blacklist_file_path,
+ std::vector<ImageLabel>* image_labels) {
+ if (!blacklist_file_path.empty()) {
+ std::vector<string> lines;
+ TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
+ std::vector<int> blacklist_ids;
+ blacklist_ids.reserve(lines.size());
+ // Populate blacklist_ids with indices of images.
+ std::transform(lines.begin(), lines.end(),
+ std::back_inserter(blacklist_ids),
+ [](const string& val) { return std::stoi(val) - 1; });
+
+ std::vector<ImageLabel> filtered_images;
+ std::sort(blacklist_ids.begin(), blacklist_ids.end());
+ const size_t size_post_filtering =
+ image_labels->size() - blacklist_ids.size();
+ filtered_images.reserve(size_post_filtering);
+ int blacklist_index = 0;
+ for (int image_index = 0; image_index < image_labels->size();
+ image_index++) {
+ if (blacklist_index < blacklist_ids.size() &&
+ blacklist_ids[blacklist_index] == image_index) {
+ blacklist_index++;
+ continue;
+ }
+ filtered_images.push_back((*image_labels)[image_index]);
+ }
+
+ if (filtered_images.size() != size_post_filtering) {
+ return errors::Internal("Invalid number of filtered images");
+ }
+ *image_labels = filtered_images;
+ }
+ return Status::OK();
+}
+
+Status ImagenetModelEvaluator::EvaluateModel() const {
+ if (model_info_.input_shapes.size() != 1) {
+ return errors::InvalidArgument("Invalid input shape");
+ }
+
+ const TensorShape& input_shape = model_info_.input_shapes[0];
+ // Input should be of the shape {1, height, width, 3}
+ if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
+ return errors::InvalidArgument("Invalid input shape for the model.");
+ }
+
string data_path =
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
@@ -174,31 +280,70 @@ Status ImagenetModelEvaluator::EvaluateModel() {
std::vector<string> image_files;
TF_CHECK_OK(
Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
- std::vector<string> image_labels;
- TF_CHECK_OK(
- utils::ReadFileLines(params_.ground_truth_labels_path, &image_labels));
- CHECK_EQ(image_files.size(), image_labels.size());
+ std::vector<string> ground_truth_image_labels;
+ TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
+ &ground_truth_image_labels));
+ CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
// Process files in filename sorted order.
std::sort(image_files.begin(), image_files.end());
+
+ std::vector<ImageLabel> image_labels;
+ image_labels.reserve(image_files.size());
+ for (int i = 0; i < image_files.size(); i++) {
+ image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
+ }
+
+ // Filter any blacklisted images.
+ TF_CHECK_OK(
+ FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
+
if (params_.number_of_images > 0) {
- image_files = GetFirstN(image_files, params_.number_of_images);
image_labels = GetFirstN(image_labels, params_.number_of_images);
}
- for (Observer* observer : observers_) {
- observer->OnEvaluationStart(image_files.size());
+ std::vector<string> model_labels;
+ TF_RETURN_IF_ERROR(
+ utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
+ if (model_labels.size() != 1001) {
+ return errors::InvalidArgument("Invalid number of labels: ",
+ model_labels.size());
}
- for (int i = 0; i < image_files.size(); i++) {
- TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_files[i]),
- CreateStringTensor(image_labels[i])));
- auto stats = eval.GetTopKAccuracySoFar();
+ ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
- for (Observer* observer : observers_) {
- observer->OnSingleImageEvaluationComplete(stats, image_files[i]);
- }
+ auto img_labels = Split(image_labels, num_threads_);
+
+ BlockingCounter counter(num_threads_);
+
+ CompositeObserver observer(observers_);
+
+ ::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
+ num_threads_);
+ std::unordered_map<uint64_t, int> shard_id_image_count_map;
+ std::vector<std::function<void()>> thread_funcs;
+ thread_funcs.reserve(num_threads_);
+ for (int i = 0; i < num_threads_; i++) {
+ const auto& image_label = img_labels[i];
+ const uint64_t shard_id = i + 1;
+ shard_id_image_count_map[shard_id] = image_label.size();
+ auto func = [shard_id, &image_label, &model_labels, this, &observer, &eval,
+ &counter]() {
+ TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
+ model_info_, params_, &observer,
+ &eval));
+ counter.DecrementCount();
+ };
+ thread_funcs.push_back(func);
}
+
+ observer.OnEvaluationStart(shard_id_image_count_map);
+ for (const auto& func : thread_funcs) {
+ pool.Schedule(func);
+ }
+
+ counter.Wait();
+
return Status::OK();
}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
index 5f42b2a50e..97e4232b35 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -56,6 +56,13 @@ class ImagenetModelEvaluator {
// Path to the model file.
string model_file_path;
+ // Path to black list file. 1762 images were blacklisted from
+ // original ILSVRC dataset. This black list file is present in
+ // ILSVRC2014 devkit. Please refer to readme.txt of the ILSVRC2014
+ // devkit for details.
+ // This file is a list of image indices in a sorted order.
+ string blacklist_file_path;
+
// The maximum number of images to calculate accuracy.
// 0 means all images, a positive number means only the specified
// number of images.
@@ -66,6 +73,7 @@ class ImagenetModelEvaluator {
};
// An evaluation observer.
+ // Observers can be called from multiple threads and need to be thread safe.
class Observer {
public:
Observer() = default;
@@ -76,38 +84,41 @@ class ImagenetModelEvaluator {
Observer& operator=(const Observer&&) = delete;
// Called on start of evaluation.
- virtual void OnEvaluationStart(int total_number_of_images) = 0;
+ // `shard_id_image_count_map` map from shard id to image count.
+ virtual void OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) = 0;
// Called when evaluation was complete for `image`.
virtual void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) = 0;
virtual ~Observer() = default;
};
ImagenetModelEvaluator(const utils::ModelInfo& model_info,
- const Params& params)
- : model_info_(model_info), params_(params) {}
+ const Params& params, const int num_threads)
+ : model_info_(model_info), params_(params), num_threads_(num_threads) {}
// Factory method to create the evaluator by parsing command line arguments.
- static Status Create(int argc, char* argv[],
+ static Status Create(int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
// Adds an observer that can observe evaluation events..
void AddObserver(Observer* observer) { observers_.push_back(observer); }
- const Params& params() { return params_; }
+ const Params& params() const { return params_; }
// Evaluates the provided model over the dataset.
- Status EvaluateModel();
+ Status EvaluateModel() const;
private:
- std::vector<Observer*> observers_;
const utils::ModelInfo model_info_;
const Params params_;
+ const int num_threads_;
+ std::vector<Observer*> observers_;
};
} // namespace metrics
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
index d46075d234..c75baa82b1 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
@@ -77,26 +77,33 @@ Status ImagenetTopKAccuracy::ComputeEval(
CHECK_EQ(kNumCategories, probabilities.size());
std::vector<int> topK = GetTopK(probabilities, k_);
int ground_truth_index = GroundTruthIndex(ground_truth_label);
- for (size_t i = 0; i < topK.size(); ++i) {
- if (ground_truth_index == topK[i]) {
- for (size_t j = i; j < topK.size(); j++) {
- accuracy_counts_[j] += 1;
- }
- break;
- }
- }
- num_samples_++;
+ UpdateSamples(topK, ground_truth_index);
return Status::OK();
}
const ImagenetTopKAccuracy::AccuracyStats
ImagenetTopKAccuracy::GetTopKAccuracySoFar() const {
+ mutex_lock lock(mu_);
AccuracyStats stats;
stats.number_of_images = num_samples_;
stats.topk_counts = accuracy_counts_;
return stats;
}
+void ImagenetTopKAccuracy::UpdateSamples(const std::vector<int>& counts,
+ int ground_truth_index) {
+ mutex_lock lock(mu_);
+ for (size_t i = 0; i < counts.size(); ++i) {
+ if (ground_truth_index == counts[i]) {
+ for (size_t j = i; j < counts.size(); j++) {
+ accuracy_counts_[j] += 1;
+ }
+ break;
+ }
+ }
+ num_samples_++;
+}
+
int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const {
auto index = std::find(ground_truth_labels_.cbegin(),
ground_truth_labels_.cend(), label);
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
index 5a575ff244..cad646a30c 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace metrics {
@@ -69,12 +70,14 @@ class ImagenetTopKAccuracy : public AccuracyEval {
private:
int GroundTruthIndex(const string& label) const;
- std::vector<string> ground_truth_labels_;
+ void UpdateSamples(const std::vector<int>& counts, int ground_truth_index);
+ const std::vector<string> ground_truth_labels_;
const int k_;
- std::vector<int> accuracy_counts_;
- int num_samples_;
+ std::vector<int> accuracy_counts_ GUARDED_BY(mu_);
+ int num_samples_ GUARDED_BY(mu_);
+ mutable mutex mu_;
};
} // namespace metrics
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
index f1e257ad10..8d997639fb 100644
--- a/tensorflow/contrib/lite/tools/benchmark/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark.
The instructions below are for running the binary on Desktop and Android,
for iOS please use the
-[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
## Parameters
@@ -17,11 +17,6 @@ The binary takes the following required parameters:
* `graph`: `string` \
The path to the TFLite model file.
-* `input_layer`: `string` \
- The name of the input layer, this is typically the first layer of the model.
-* `input_layer_shape`: `string` \
- The shape of the input layer. This is a comma separated string of the shape
- of tensor of input layer.
and the following optional parameters:
@@ -29,11 +24,13 @@ and the following optional parameters:
The number of threads to use for running TFLite interpreter.
* `warmup_runs`: `int` (default=1) \
The number of warmup runs to do before starting the benchmark.
+* `num_runs`: `int` (default=50) \
+ The number of runs. Increase this to reduce variance.
* `run_delay`: `float` (default=-1.0) \
The delay in seconds between subsequent benchmark runs. Non-positive values
mean use no delay.
* `use_nnapi`: `bool` (default=false) \
- Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+ Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
This API is available on recent Android devices.
## To build/install/run
@@ -75,8 +72,6 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp
```
adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
@@ -93,13 +88,10 @@ For example:
```
bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
--graph=mobilenet_quant_v1_224.tflite \
- --input_layer="Placeholder" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
-The MobileNet graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+The MobileNet graph used as an example here may be downloaded from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
## Reducing variance between runs on Android.
@@ -117,8 +109,6 @@ can use the following command:
```
adb shell taskset f0 /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=1
```
@@ -205,5 +195,3 @@ Memory (bytes): count=0
Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
```
-
-
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
index c8d3307e29..46144f7bf8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/ios/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
@@ -17,8 +17,8 @@ Mobilenet_1.0_224 model
## To build/install/run
-- Follow instructions at [iOS build for TFLite]
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
+- Follow instructions at
+[iOS build for TFLite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
to build TFLite.
Running
diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD
index 01fbce0ac7..51ccaedc23 100644
--- a/tensorflow/contrib/lite/tools/optimize/BUILD
+++ b/tensorflow/contrib/lite/tools/optimize/BUILD
@@ -9,3 +9,17 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "quantize_weights",
+ srcs = ["quantize_weights.cc"],
+ hdrs = ["quantize_weights.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:tflite_portable_logging",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
new file mode 100644
index 0000000000..93fe576583
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
@@ -0,0 +1,70 @@
+# TFLite Quantize Weights Tool
+
+## Recommended usage
+
+The Quantize Weights transformation is integrated with
+[tflite_convert](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md#transformation-flags).
+
+The recommended way of invoking this tool is by simply adding the
+`--post_training_quantize` flag to your original tflite_convert invocation. For
+example,
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --saved_model_dir=/tmp/saved_model \
+ --post_training_quantize
+```
+
+## Overview
+
+The Quantize Weights tool provides a simple way to quantize the weights for a
+float TFLite model.
+
+TODO(raghuramank): Add link to weight quantization tutorial.
+
+### Size reduction
+
+float32 weights will be converted to 8 bit integers. This results in a model
+that is around 1/4th the size of the original model.
+
+### Latency reduction
+
+TFLite also has "hybrid" kernels implemented for many operations. These "hybrid"
+kernels take 8 bit integer weights and float inputs, dynamically quantize the
+inputs tensor (based on the input tensor's min and max elements), and does
+computations using the 8 bit integer values. This results in a 2-4x reduction in
+latency for "hybrid" kernels. In this mode the inference type is still FLOAT
+since the inputs and output to each operation is still float.
+
+For operations that do not yet have "hybrid" kernels implemented, we introduce a
+Dequantize operation after 8 bit integer weights. These convert weights back to
+float32 during inference to allow original float32 kernels to run. Since we
+cache dequantized results, the result of each of this dequantized path will be
+on-par with the original float model.
+
+TODO(yunluli): Fill in latency results from latency experiments.
+
+### Accuracy
+
+Since this technique quantizes weights after the model has already been trained,
+there can be accuracy drops depending on the model. For common CNN networks, the
+observed accuracy drops are small and can be seen below.
+
+TODO(yunluli): Fill in accuracy results from accuracy experiments.
+
+## Direct usage
+
+One can also invoke the Quantize Weights directly via C++ if they have a float
+`::tflite::Model` that they want to convert. They must provide a
+`flatbuffers::FlatBufferBuilder` which owns the underlying buffer of the created
+model. Here is an example invocation:
+
+```
+::tflite::Model* input_model = ...;
+flatbuffers::FlatBufferBuilder builder;
+TfLiteStatus status = ::tflite::optimize::QuantizeWeights(&builder, input_model);
+CHECK(status, kTfLiteStatusOk);
+const uint8_t* buffer = builder->GetBufferPointer();
+tflite::Model* output_model = ::tflite::GetModel(buffer);
+```
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index 0758514e39..e0ed7c7946 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include <vector>
#include "flatbuffers/flexbuffers.h"
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,6 +32,16 @@ namespace optimize {
namespace {
+typedef struct {
+ TensorT* tensor;
+ // The index of the tensor to quantize in subgraph->tensors.
+ int32_t tensor_idx;
+ // The index of the tensor of the weight tensor to be quantize in op->inputs.
+ int32_t op_input_idx;
+ // True if the tensor supports hybrid evaluation.
+ bool eval_hybrid;
+} TensorInfo;
+
// The minimum number of elements a weights array must have to be quantized
// by this transformation.
// TODO(suharshs): Make this configurable.
@@ -41,9 +53,9 @@ const int kWeightsMinSize = 1024;
// Although this code originates from FakeQuantization in quantized training,
// we may deviate from that implementation as we please since we do not fine
// tune the weights with quantized training.
-void GetQuantizationParams(const float min, const float max,
- const int quant_min, const int quant_max,
- QuantizationParametersT* quantization_params) {
+void GetAsymmetricQuantizationParams(
+ const float min, const float max, const int quant_min, const int quant_max,
+ QuantizationParametersT* quantization_params) {
// Adjust the boundaries to guarantee 0 is included.
const float quant_min_float = std::min(static_cast<float>(quant_min), 0.0f);
const float quant_max_float = std::max(static_cast<float>(quant_max), 0.0f);
@@ -57,25 +69,25 @@ void GetQuantizationParams(const float min, const float max,
} else {
zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
}
- quantization_params->scale = {scale};
- quantization_params->zero_point = {zero_point};
+ quantization_params->scale = std::vector<float>(1, scale);
+ quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
}
// Returns the number of elements in tensor.
-uint64 NumElements(const TensorT* tensor) {
+uint64_t NumElements(const TensorT* tensor) {
if (tensor->shape.empty()) {
LOG(FATAL) << "Tensor has no shape information.";
}
- uint64 num_elements = 1;
- for (const uint64 dim : tensor->shape) {
+ uint64_t num_elements = 1;
+ for (const uint64_t dim : tensor->shape) {
num_elements *= dim;
}
return num_elements;
}
-uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
- int32_t tensor_idx) {
- uint64 count = 0;
+uint64_t CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
+ int32_t tensor_idx) {
+ uint64_t count = 0;
for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
const OperatorT* op = subgraph->operators[op_idx].get();
if (op == nullptr) {
@@ -90,49 +102,118 @@ uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
return count;
}
-// Returns true if the Operator's weight tensor should be quantized.
-bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op,
- TensorT** tensor, int32_t* tensor_idx,
- int32_t* op_input_index) {
- SubGraphT* subgraph = model->subgraphs.at(0).get();
- const BuiltinOperator op_code =
- model->operator_codes[op->opcode_index]->builtin_code;
-
+// Gets the list of op->inputs indices of the weights inputs to be quantized for
+// the provided op.
+std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
if (op_code == BuiltinOperator_CONV_2D ||
op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
op_code == BuiltinOperator_FULLY_CONNECTED ||
- op_code == BuiltinOperator_SVDF) {
- *op_input_index = 1;
- } else if (op_code == BuiltinOperator_LSTM) {
- // TODO(suharshs): Add RNN, and sequential/bidi versions.
- *op_input_index = 2;
- } else {
- return false;
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
+ return {1};
+ } else if (op_code == BuiltinOperator_SVDF) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/svdf.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/lstm.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
+ } else if (op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/basic_rnn.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16,
+ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+ return {1, 2, 4, 5};
}
- *tensor_idx = op->inputs[*op_input_index];
-
- // TODO(suharshs): Support shared weights, i.e. If two tensors share the
- // same weight array, things may break. (i.e. SSD object detection)
- if (CountTensorConsumers(model, subgraph, *tensor_idx) != 1) {
- LOG(INFO) << "Skipping quantization of tensor that is shared between "
- "multiple multiple operations.";
- return false;
+ return {};
+}
+
+// Returns true if the operator supports hybrid evaluation.
+bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
+ // Operations that support hybrid evaluation.
+ bool eval_hybrid = false;
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
+ op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ eval_hybrid = true;
+ } else if (op_code == BuiltinOperator_LSTM) {
+ const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
+ // Only lstm kernel_type full supports hybrid evaluation.
+ if (options->kernel_type == LSTMKernelType_FULL) {
+ eval_hybrid = true;
+ }
}
+ return eval_hybrid;
+}
+
+// Returns a vector of TensorInfos for each input tensor of op that should be
+// quantized.
+std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
+ const OperatorT* op) {
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+ const BuiltinOperator op_code =
+ model->operator_codes[op->opcode_index]->builtin_code;
+
+ std::vector<TensorInfo> tensor_infos;
+
+ bool eval_hybrid = IsHybridEvaluationOp(op, op_code);
+
+ bool skipped_tensor = false;
+ std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
+ for (const int32_t op_input_idx : op_input_indices) {
+ int32_t tensor_idx = op->inputs[op_input_idx];
+
+ // TODO(suharshs): Support shared weights, i.e. If two tensors share the
+ // same weight array, things may break. (i.e. SSD object detection)
+ if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
+ LOG(INFO) << "Skipping quantization of tensor that is shared between "
+ "multiple multiple operations.";
+ skipped_tensor = true;
+ continue;
+ }
+
+ TensorT* tensor = subgraph->tensors[tensor_idx].get();
- *tensor = subgraph->tensors[*tensor_idx].get();
+ if (tensor->type != TensorType_FLOAT32) {
+ LOG(INFO) << "Skipping quantization of tensor that is not type float.";
+ skipped_tensor = true;
+ continue;
+ }
+
+ const uint64_t num_elements = NumElements(tensor);
+ if (num_elements < kWeightsMinSize) {
+ LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
+ << kWeightsMinSize << " elements (" << num_elements << ").";
+ skipped_tensor = true;
+ continue;
+ }
- if ((*tensor)->type != TensorType_FLOAT32) {
- LOG(INFO) << "Skipping quantization of tensor that is not type float.";
- return false;
+ TensorInfo tensor_info;
+ tensor_info.eval_hybrid = eval_hybrid;
+ tensor_info.op_input_idx = op_input_idx;
+ tensor_info.tensor_idx = tensor_idx;
+ tensor_info.tensor = tensor;
+
+ tensor_infos.push_back(tensor_info);
}
- const uint64 num_elements = NumElements(*tensor);
- if (num_elements < kWeightsMinSize) {
- LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
- << kWeightsMinSize << " elements (" << num_elements << ").";
- return false;
+
+ // For hybrid operations we either need to quantize all tensors or none. So
+ // if we skipped any tensors we need to return no quantized tensors.
+ if (eval_hybrid && skipped_tensor) {
+ return {};
}
- return true;
+ return tensor_infos;
}
// Quantizes tensor using asymmetric quantization with the min and max elements
@@ -140,14 +221,19 @@ bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op,
TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
BufferT* buffer = model->buffers[tensor->buffer].get();
float* float_data = reinterpret_cast<float*>(buffer->data.data());
- const uint64 num_elements = NumElements(tensor);
- LOG(INFO) << "Quantizing tensor with " << num_elements << " elements.";
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for float evaluation.";
// Compute the quantization params.
float min_value = *std::min_element(float_data, float_data + num_elements);
float max_value = *std::max_element(float_data, float_data + num_elements);
- GetQuantizationParams(min_value, max_value, 0, 255,
- tensor->quantization.get());
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ GetAsymmetricQuantizationParams(min_value, max_value, 0, 255,
+ tensor->quantization.get());
// Quantize the buffer.
std::vector<uint8_t> quantized_buffer;
@@ -173,6 +259,40 @@ TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
return kTfLiteOk;
}
+// Quantizes tensor using symmetric quantization with the min and max elements
+// of the tensor. This is need for operations with hybrid evaluation
+// implemented.
+TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for hybrid evaluation.";
+
+ std::vector<int8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+
+ float min_value, max_value, scaling_factor;
+ tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
+ quantized_buffer.data(), &min_value,
+ &max_value, &scaling_factor);
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ tensor->quantization->scale = std::vector<float>(1, scaling_factor);
+ tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
+
+ uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
+ model->buffers[tensor->buffer]->data.assign(uint8_buffer,
+ uint8_buffer + num_elements);
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
// Returns the index of the Dequantize op_code.
// If a Dequantize op_code doesn't exist, adds it and returns its index.
int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
@@ -181,7 +301,7 @@ int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
return i;
}
}
- model->operator_codes.push_back(std::make_unique<OperatorCodeT>());
+ model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
int op_code_idx = model->operator_codes.size() - 1;
model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
// TODO(suharshs): How should the version be set in this op_code?
@@ -214,7 +334,8 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape,
} // namespace
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
- const Model* input_model) {
+ const Model* input_model,
+ bool use_hybrid_evaluation) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());
@@ -231,40 +352,42 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
- TensorT* tensor;
- // The index of the weight tensor in subgraph->tensors.
- int32_t tensor_idx;
- int32_t op_input_idx; // The index of tensor_idx in the op->inputs.
- // TODO(suharshs): Support hybrid ops that require symmetric quantization.
- if (GetQuantizableTensorFromOperator(model.get(), op, &tensor, &tensor_idx,
- &op_input_idx)) {
- // Quantize the tensors.
- TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(model.get(), tensor));
-
- // Create a new tensor to be the output of the dequantize op.
- std::unique_ptr<TensorT> dequantize_output;
- MakeTensor(tensor->name + "_dequantize", tensor->shape,
- &dequantize_output);
- int32_t dequantize_output_idx = subgraph->tensors.size();
- subgraph->tensors.push_back(std::move(dequantize_output));
-
- // Create the Dequantize operation.
- std::unique_ptr<OperatorT> dequantize_op;
- MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
- dequantize_output_idx);
-
- // Update the op_input of tensor_idx to dequantize_output_idx.
- op->inputs[op_input_idx] = dequantize_output_idx;
- // Insert the updated op.
- new_operators.push_back(std::move(subgraph->operators[i]));
-
- // Insert the newly created Dequantize operation.
- new_operators.push_back(std::move(dequantize_op));
- } else {
- // If this tensor wasn't quantizable, just copy the op over as-is.
- new_operators.push_back(std::move(subgraph->operators[i]));
+ std::vector<TensorInfo> tensor_infos =
+ GetQuantizableTensorsFromOperator(model.get(), op);
+
+ for (const TensorInfo& tensor_info : tensor_infos) {
+ if (use_hybrid_evaluation && tensor_info.eval_hybrid) {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ SymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+ } else {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ AsymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+
+ // Create a new tensor to be the output of the dequantize op.
+ std::unique_ptr<TensorT> dequantize_output;
+ MakeTensor(tensor_info.tensor->name + "_dequantize",
+ tensor_info.tensor->shape, &dequantize_output);
+ const int32_t dequantize_output_idx = subgraph->tensors.size();
+ subgraph->tensors.push_back(std::move(dequantize_output));
+
+ // Create the Dequantize operation.
+ std::unique_ptr<OperatorT> dequantize_op;
+ MakeDequantizeOperator(model.get(), &dequantize_op,
+ tensor_info.tensor_idx, dequantize_output_idx);
+
+ // Update the op_input of tensor_idx to dequantize_output_idx.
+ op->inputs[tensor_info.op_input_idx] = dequantize_output_idx;
+
+ // Insert the newly created Dequantize operation.
+ new_operators.push_back(std::move(dequantize_op));
+ }
}
+ // After (maybe) quantizing inputs, we copy the operator into the new list.
+ new_operators.push_back(std::move(subgraph->operators[i]));
}
+
// At this point all unique_ptrs in the original operators are invalid, and
// we need to replace it with the new_operators vector.
subgraph->operators = std::move(new_operators);
@@ -276,5 +399,10 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
return kTfLiteOk;
}
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model) {
+ return QuantizeWeights(builder, input_model, true);
+}
+
} // namespace optimize
} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
index a408c1662d..3743c0ce53 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -32,6 +32,12 @@ namespace optimize {
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model);
+// Same as above, but if use_hybrid_evaluation is false, will disable using
+// hybrid eval for operations that support it.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation);
+
} // namespace optimize
} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
index 0e0676e5ff..efaf9929e9 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -48,29 +48,67 @@ class QuantizeWeightsTest : public ::testing::Test {
return nullptr;
}
- void CheckWeights(const Model* model_packed) {
- std::unique_ptr<ModelT> model;
- model.reset(model_packed->UnPack());
+ void SymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer,
+ float scale) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const int8_t* output_buffer_data =
+ reinterpret_cast<const int8_t*>(output_buffer->data.data());
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff = input_buffer_data[i] - (output_buffer_data[i] * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void AsymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer, float scale,
+ int64_t zero_point) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const uint8_t* output_buffer_data = output_buffer->data.data();
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff =
+ input_buffer_data[i] - ((output_buffer_data[i] - zero_point) * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void CheckWeights(const Model* input_model_packed,
+ const Model* output_model_packed,
+ bool use_hybrid_evaluation) {
+ std::unique_ptr<ModelT> input_model;
+ input_model.reset(input_model_packed->UnPack());
- SubGraphT* subgraph = model->subgraphs.at(0).get();
+ std::unique_ptr<ModelT> output_model;
+ output_model.reset(output_model_packed->UnPack());
+
+ SubGraphT* subgraph = output_model->subgraphs.at(0).get();
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
const BuiltinOperator op_code =
- model->operator_codes[op->opcode_index]->builtin_code;
+ output_model->operator_codes[op->opcode_index]->builtin_code;
// These are the operations that should be quantized.
+ // TODO(suharshs): Right now this test only checks the relevant operations
+ // for the mobilenet v1 model used in the tests below.
int32_t tensor_idx;
if (op_code == BuiltinOperator_CONV_2D ||
op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
op_code == BuiltinOperator_FULLY_CONNECTED) {
tensor_idx = op->inputs[1];
- } else if (op_code == BuiltinOperator_LSTM) {
- // TODO(suharshs): Add tests for LSTMs.
- tensor_idx = op->inputs[1];
} else {
continue;
}
+
+ bool eval_hybrid = false;
+ // These are the ops that support hybrid evaluation.
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D) {
+ eval_hybrid = true;
+ }
+
const TensorT* tensor = subgraph->tensors[tensor_idx].get();
int tensor_size = GetElementsNum(tensor);
// If the tensor_size is less than 1024 we expect the tensor to remain
@@ -80,27 +118,45 @@ class QuantizeWeightsTest : public ::testing::Test {
const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
// The weight tensor should not come from a dequantize op.
ASSERT_TRUE(preceding_op == nullptr);
+ } else if (use_hybrid_evaluation && eval_hybrid) {
+ // The input to the op should still be uint8.
+ ASSERT_TRUE(tensor->type == TensorType_UINT8) << tensor->name;
+ // The weight tensor should not come from a dequantize op.
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op == nullptr);
+
+ // Test symmetric quantization.
+ SymmetricDequantizeAndCompare(
+ input_model->buffers[tensor->buffer].get(),
+ output_model->buffers[tensor->buffer].get(),
+ tensor->quantization->scale[0]);
+
} else {
// The input to the op should still be float.
ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
ASSERT_TRUE(preceding_op != nullptr);
// The float input should be the dequantize output.
- ASSERT_TRUE(
- model->operator_codes[preceding_op->opcode_index]->builtin_code ==
- BuiltinOperator_DEQUANTIZE);
+ ASSERT_TRUE(output_model->operator_codes[preceding_op->opcode_index]
+ ->builtin_code == BuiltinOperator_DEQUANTIZE);
// Finally, ensure that the input to the dequantize operation is
// quantized.
- ASSERT_TRUE(subgraph->tensors[preceding_op->inputs[0]]->type ==
- TensorType_UINT8);
- // TODO(suharshs): Add more rigorous testing for the numerical values in
- // the tensors.
+ const TensorT* quantized_tensor =
+ subgraph->tensors[preceding_op->inputs[0]].get();
+ ASSERT_TRUE(quantized_tensor->type == TensorType_UINT8);
+
+ // Test the assymetric quantization.
+ AsymmetricDequantizeAndCompare(
+ input_model->buffers[quantized_tensor->buffer].get(),
+ output_model->buffers[quantized_tensor->buffer].get(),
+ quantized_tensor->quantization->scale[0],
+ quantized_tensor->quantization->zero_point[0]);
}
}
}
};
-TEST_F(QuantizeWeightsTest, SimpleTest) {
+TEST_F(QuantizeWeightsTest, SimpleTestWithHybrid) {
string model_path =
"third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
"mobilenet_v1_0.25_128.tflite";
@@ -114,7 +170,25 @@ TEST_F(QuantizeWeightsTest, SimpleTest) {
const uint8_t* buffer = builder.GetBufferPointer();
const Model* output_model = GetModel(buffer);
- CheckWeights(output_model);
+ CheckWeights(input_model, output_model, true);
+}
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ // Disable hybrid evaluation.
+ EXPECT_EQ(QuantizeWeights(&builder, input_model, false), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(input_model, output_model, false);
}
// TODO(suharshs): Add tests that run the resulting model.
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 7d26429f9c..22b11f1c57 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -56,7 +56,6 @@ tensorflow/core/lib/hash/hash.cc
tensorflow/core/lib/hash/crc32c.cc
tensorflow/core/lib/hash/crc32c_accelerate.cc
tensorflow/core/lib/core/threadpool.cc
-tensorflow/core/lib/core/stringpiece.cc
tensorflow/core/lib/core/status.cc
tensorflow/core/lib/core/coding.cc
tensorflow/core/lib/core/arena.cc
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index e662b11be8..3cffd76a25 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -113,7 +113,7 @@ py_library(
py_test(
name = "pruning_utils_test",
- size = "small",
+ size = "medium",
srcs = ["python/pruning_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 5319a8b655..93e589907e 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -22,6 +22,7 @@ py_library(
"python/training/ggt.py",
"python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
+ "python/training/matrix_functions.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
"python/training/multitask_optimizer_wrapper.py",
@@ -381,3 +382,18 @@ py_test(
"@six_archive//:six",
],
)
+
+py_test(
+ name = "matrix_functions_test",
+ srcs = ["python/training/matrix_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py
new file mode 100644
index 0000000000..baab577638
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Matrix functions contains iterative methods for M^p."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
+ """Iterative method to get matrix square root.
+
+ Stable iterations for the matrix square root, Nicholas J. Higham
+
+ Page 231, Eq 2.6b
+ http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf
+
+ Args:
+ mat_a: the symmetric PSD matrix whose matrix square root be computed
+ mat_a_size: size of mat_a.
+ iter_count: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_a^0.5
+ """
+
+ def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
+ unused_old_mat_z, err, old_err):
+ # This method require that we check for divergence every step.
+ return math_ops.logical_and(i < iter_count, err < old_err)
+
+ def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
+ unused_old_err):
+ current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y))
+ current_mat_y = math_ops.matmul(mat_y, current_iterate)
+ current_mat_z = math_ops.matmul(current_iterate, mat_z)
+ # Compute the error in approximation.
+ mat_sqrt_a = current_mat_y * math_ops.sqrt(norm)
+ mat_a_approx = math_ops.matmul(mat_sqrt_a, mat_sqrt_a)
+ residual = mat_a - mat_a_approx
+ current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm
+ return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_a_size))
+ mat_a = mat_a + ridge_epsilon * identity
+ norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a))
+ mat_init_y = mat_a / norm
+ mat_init_z = identity
+ init_err = norm
+
+ _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop(
+ _iter_condition, _iter_body, [
+ 0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
+ init_err + 1.0
+ ])
+ return prev_mat_y * math_ops.sqrt(norm)
+
+
+def matrix_inverse_pth_root(mat_g,
+ mat_g_size,
+ alpha,
+ iter_count=100,
+ epsilon=1e-6,
+ ridge_epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+
+ We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
+
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
+ by Chun-Hua Guo and Nicholas J. Higham
+ SIAM Journal on Matrix Analysis and Applications,
+ 2006, Vol. 28, No. 3 : pp. 788-804
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
+
+ Args:
+ mat_g: the symmetric PSD matrix whose power it to be computed
+ mat_g_size: size of mat_g.
+ alpha: exponent, must be -1/p for p a positive integer.
+ iter_count: Maximum number of iterations.
+ epsilon: accuracy indicator, useful for early termination.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_g^alpha
+ """
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
+
+ def mat_power(mat_m, p):
+ """Computes mat_m^p, for p a positive integer.
+
+ Power p is known at graph compile time, so no need for loop and cond.
+ Args:
+ mat_m: a square matrix
+ p: a positive integer
+
+ Returns:
+ mat_m^p
+ """
+ assert p == int(p) and p > 0
+ power = None
+ while p > 0:
+ if p % 2 == 1:
+ power = math_ops.matmul(mat_m, power) if power is not None else mat_m
+ p //= 2
+ mat_m = math_ops.matmul(mat_m, mat_m)
+ return power
+
+ def _iter_condition(i, mat_m, _):
+ return math_ops.logical_and(
+ i < iter_count,
+ math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
+
+ def _iter_body(i, mat_m, mat_x):
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
+ return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha), mat_m),
+ math_ops.matmul(mat_x, mat_m_i))
+
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha)
+ else:
+ damped_mat_g = mat_g + ridge_epsilon * identity
+ z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
+ # The best value for z is
+ # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
+ # (c_max^{1-alpha} - c_min^{1-alpha})
+ # where c_max and c_min are the largest and smallest singular values of
+ # damped_mat_g.
+ # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
+ # Can replace above line by the one below, but it is less accurate,
+ # hence needs more iterations to converge.
+ # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
+ # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
+ # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
+ # extra iterations.
+ _, _, mat_h = control_flow_ops.while_loop(
+ _iter_condition, _iter_body,
+ [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
+ return mat_h
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions_test.py b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
new file mode 100644
index 0000000000..518fa38233
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Matrix functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import matrix_functions
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class MatrixFunctionTests(test.TestCase):
+
+ def testMatrixSquareRootFunction(self):
+ """Tests for matrix square roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, 0.5)
+ mat_root = matrix_functions.matrix_square_root(mat, size)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testMatrixInversePthRootFunction(self):
+ """Tests for matrix inverse pth roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, -0.125)
+ mat_root = matrix_functions.matrix_inverse_pth_root(mat, size, -0.125)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index 294627f42a..f161521b97 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from tensorflow.contrib.opt.python.training import matrix_functions
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -76,7 +77,7 @@ class ShampooOptimizer(optimizer.Optimizer):
learning_rate=1.0,
svd_interval=1,
precond_update_interval=1,
- epsilon=0.1,
+ epsilon=1e-4,
alpha=0.5,
use_iterative_root=False,
use_locking=False,
@@ -255,81 +256,18 @@ class ShampooOptimizer(optimizer.Optimizer):
def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
iter_count=100, epsilon=1e-6):
- """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
+
+ mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
+ iter_count, self._epsilon)
+ mat_h = matrix_functions.matrix_inverse_pth_root(
+ mat_g_sqrt,
+ mat_g_size,
+ 2 * alpha,
+ iter_count,
+ epsilon,
+ ridge_epsilon=0.0)
- We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
-
- A Schur-Newton Method for the Matrix p-th Root and its Inverse
- by Chun-Hua Guo and Nicholas J. Higham
- SIAM Journal on Matrix Analysis and Applications,
- 2006, Vol. 28, No. 3 : pp. 788-804
- https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
-
- Args:
- var: the variable we are updating.
- mat_g: the symmetric PSD matrix whose power it to be computed
- mat_g_size: size of mat_g.
- alpha: exponent, must be -1/p for p a positive integer.
- mat_h_slot_name: name of slot to store the power, if needed.
- iter_count: Maximum number of iterations.
- epsilon: accuracy indicator, useful for early termination.
-
- Returns:
- mat_g^alpha
- """
-
- identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
-
- def MatPower(mat_m, p):
- """Computes mat_m^p, for p a positive integer.
-
- Power p is known at graph compile time, so no need for loop and cond.
- Args:
- mat_m: a square matrix
- p: a positive integer
-
- Returns:
- mat_m^p
- """
- assert p == int(p) and p > 0
- power = None
- while p > 0:
- if p % 2 == 1:
- power = math_ops.matmul(mat_m, power) if power is not None else mat_m
- p //= 2
- mat_m = math_ops.matmul(mat_m, mat_m)
- return power
-
- def IterCondition(i, mat_m, _):
- return math_ops.logical_and(
- i < iter_count,
- math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
-
- def IterBody(i, mat_m, mat_x):
- mat_m_i = (1 - alpha) * identity + alpha * mat_m
- return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m),
- math_ops.matmul(mat_x, mat_m_i))
-
- if mat_g_size == 1:
- mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
- else:
- damped_mat_g = mat_g + self._epsilon * identity
- z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
- # The best value for z is
- # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
- # (c_max^{1-alpha} - c_min^{1-alpha})
- # where c_max and c_min are the largest and smallest singular values of
- # damped_mat_g.
- # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
- # Can replace above line by the one below, but it is less accurate,
- # hence needs more iterations to converge.
- # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
- # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
- # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
- # extra iterations.
- _, _, mat_h = control_flow_ops.while_loop(
- IterCondition, IterBody,
- [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
if mat_h_slot_name is not None:
return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
return mat_h
@@ -422,6 +360,8 @@ class ShampooOptimizer(optimizer.Optimizer):
mat_gbar_weight_t * precond_update_interval, i),
lambda: mat_g)
+ mat_g_updated = mat_g_updated / float(shape[i].value)
+
if self._svd_interval == 1:
mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
else:
@@ -443,7 +383,13 @@ class ShampooOptimizer(optimizer.Optimizer):
name="precond_" + str(i))
else:
# Tensor size is too large -- perform diagonal Shampoo update
- grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
+ # Only normalize non-vector cases.
+ if axes:
+ normalizer = 1.0 if indices is not None else float(shape[i].value)
+ grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
+ else:
+ grad_outer = grad * grad
+
if i == 0 and indices is not None:
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index b3688ab181..05bcf2cfa3 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
TOLERANCE = 1e-3
+RIDGE_EPSILON = 1e-4
def np_power(mat_g, alpha):
@@ -77,8 +78,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * mat_g^{-0.5} * grad
# lr = 1
- mat_g = np.outer(grad_np, grad_np)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np = init_var_np - np.dot(mat_h, grad_np)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -88,8 +89,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g += np.outer(grad_np_2, grad_np_2)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np -= np.dot(mat_h, grad_np_2)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -128,10 +129,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25}
# lr = 1
- mat_g1 = np.dot(grad_np, grad_np.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -141,10 +142,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.dot(grad_np_2, grad_np_2.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -188,12 +189,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = (
+ np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) /
+ grad_np.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) /
+ grad_np.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) /
+ grad_np.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -207,12 +214,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) /
+ grad_np_2.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) /
+ grad_np_2.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) /
+ grad_np_2.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -265,19 +278,21 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
- mat_g = grad_np * grad_np + 0.1
- new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
-
- self.assertAllCloseAccordingToType(new_val_np, new_val)
+ mat_g = (grad_np * grad_np)
+ new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
# Run another step of Shampoo
update_2.run()
new_val = sess.run(var)
- mat_g += grad_np_2 * grad_np_2
- new_val_np -= np.power(mat_g, -0.5) * grad_np_2
+ mat_g += (grad_np_2 * grad_np_2)
+ new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
- self.assertAllCloseAccordingToType(new_val_np, new_val)
@parameterized.named_parameters(('Var', False), ('ResourceVar', True))
def testLargeMatrix(self, use_resource_var):
@@ -322,10 +337,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# with broadcasting
# lr = 1
- mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.sum(
+ grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -335,10 +351,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.sum(
+ grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -405,9 +422,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
mat_g1_acc = np.zeros((size[0], 1))
mat_g1_acc[grad_indices] += mat_g1
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np
new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right)
@@ -420,9 +437,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
mat_g1_acc[grad_indices_2] += mat_g1
- mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -474,12 +491,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_dense = np.zeros_like(init_var_np)
grad_dense[grad_indices] = grad_np
- mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -536,12 +556,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np = gbar_weight * grad_np
precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0]))
@@ -556,12 +579,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2
precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0]))
@@ -626,13 +652,19 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
+ mat_g1 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0]
+ mat_g2 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1]
+ mat_g3 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2]
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -700,17 +732,23 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
if (i + 1) % precond_update_interval == 0:
- mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- * precond_update_interval)
- mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- * precond_update_interval)
- mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
- * precond_update_interval)
+ mat_g1 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) /
+ grad_np[i].shape[0] * precond_update_interval)
+ mat_g2 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) /
+ grad_np[i].shape[1] * precond_update_interval)
+ mat_g3 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) /
+ grad_np[i].shape[2] * precond_update_interval)
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index 29acfc602e..200b0d2008 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.opt.python.training import shampoo
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -361,3 +362,74 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
super(AdamWOptimizer, self).__init__(
weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
epsilon=epsilon, use_locking=use_locking, name=name)
+
+
+@tf_export("contrib.opt.ShampooWOptimizer")
+class ShampooWOptimizer(DecoupledWeightDecayExtension,
+ shampoo.ShampooOptimizer):
+ """Optimizer that implements the Shampoo algorithm with weight decay.
+
+ For further information see the documentation of the Shampoo Optimizer.
+ """
+
+ def __init__(self,
+ weight_decay,
+ global_step,
+ max_matrix_size=768,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=1e-4,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="ShampooW"):
+ """Construct a new ShampooW optimizer.
+
+ For further information see the documentation of the Shampoo Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar: gbar[t] = gbar_decay[t] * gbar[t-1] +
+ gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar: mat_gbar_j[t] =
+ mat_gbar_decay[t] * mat_gbar_j[t-1] + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
+ also OK. May also want more often early,
+ and less often later - set in caller as for example:
+ "svd_interval = lambda(T): tf.cond(
+ T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after this
+ many steps. Default = 1. Usually less than svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+ use_iterative_root: should the optimizer use SVD (faster) or the iterative
+ root method (for TPU) for finding the roots of PSD matrices.
+ use_locking: If `True` use locks for update operations.
+ name: name of optimizer.
+ """
+ super(ShampooWOptimizer, self).__init__(
+ weight_decay,
+ global_step=global_step,
+ max_matrix_size=max_matrix_size,
+ gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+ mat_gbar_decay=mat_gbar_weight,
+ learning_rate=learning_rate,
+ svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ epsilon=epsilon,
+ alpha=alpha,
+ use_iterative_root=use_iterative_root,
+ use_locking=use_locking,
+ name=name)
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index e7eb4ac563..b897224c6d 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -36,6 +36,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":keras_saved_model",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
@@ -101,23 +102,33 @@ py_library(
tags = ["no_windows"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:saver",
"//tensorflow/python:util",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export",
+ "//tensorflow/python/estimator:keras",
+ "//tensorflow/python/estimator:model_fn",
"//tensorflow/python/keras:engine",
- "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model",
],
)
py_test(
name = "keras_saved_model_test",
- size = "small",
+ size = "medium",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":saved_model_py",
+ ":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/keras",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 95e1a8967b..074dc655ac 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -26,10 +26,13 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
-# pylint: enable=unused-import,widcard-import,line-too-long
+# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
+_allowed_symbols = [
+ "get_signature_def_by_key",
+ "load_keras_model",
+ "save_keras_model"]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
index e2a969f053..2c5c8c4afd 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
@@ -20,28 +20,69 @@ from __future__ import print_function
import os
+from tensorflow.python.client import session
+from tensorflow.python.estimator import keras as estimator_keras_util
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export as export_helpers
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import models as models_lib
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
-def save_model(model, saved_model_path):
+def save_keras_model(
+ model, saved_model_path, custom_objects=None, as_text=None):
"""Save a `tf.keras.Model` into Tensorflow SavedModel format.
- `save_model` generates such files/folders under the `saved_model_path` folder:
+ `save_model` generates new files/folders under the `saved_model_path` folder:
1) an asset folder containing the json string of the model's
- configuration(topology).
+ configuration (topology).
2) a checkpoint containing the model weights.
+ 3) a saved_model.pb file containing the model's MetaGraphs. The prediction
+ graph is always exported. The evaluaton and training graphs are exported
+ if the following conditions are met:
+ - Evaluation: model loss is defined.
+ - Training: model is compiled with an optimizer defined under `tf.train`.
+ This is because `tf.keras.optimizers.Optimizer` instances cannot be
+ saved to checkpoints.
- Note that subclassed models can not be saved via this function, unless you
- provide an implementation for get_config() and from_config().
- Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
- saved to checkpoints. Use optimizers from `tf.train`.
+ Model Requirements:
+ - Model must be a sequential model or functional model. Subclassed models can
+ not be saved via this function, unless you provide an implementation for
+ get_config() and from_config().
+ - All variables must be saveable by the model. In general, this condition is
+ met through the use of layers defined in the keras library. However,
+ there is currently a bug with variables created in Lambda layer functions
+ not being saved correctly (see
+ https://github.com/keras-team/keras/issues/9740).
+
+ Note that each mode is exported in separate graphs, so different modes do not
+ share variables. To use the train graph with evaluation or prediction graphs,
+ create a new checkpoint if variable values have been updated.
Args:
model: A `tf.keras.Model` to be saved.
saved_model_path: a string specifying the path to the SavedModel directory.
+ The SavedModel will be saved to a timestamped folder created within this
+ directory.
+ custom_objects: Optional dictionary mapping string names to custom classes
+ or functions (e.g. custom loss functions).
+ as_text: whether to write the `SavedModel` proto in text format.
+
+ Returns:
+ String path to the SavedModel folder, a subdirectory of `saved_model_path`.
Raises:
NotImplementedError: If the passed in model is a subclassed model.
@@ -49,35 +90,200 @@ def save_model(model, saved_model_path):
if not model._is_graph_network:
raise NotImplementedError
- # save model configuration as a json string under assets folder.
- model_json = model.to_json()
- assets_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.ASSETS_DIRECTORY))
+ export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
+ temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
+
+ builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
+
+ # Manually save variables to export them in an object-based checkpoint. This
+ # skips the `builder.add_meta_graph_and_variables()` step, which saves a
+ # named-based checkpoint.
+ # TODO(b/113134168): Add fn to Builder to save with object-based saver.
+ # TODO(b/113178242): This should only export the model json structure. Only
+ # one save is needed once the weights can be copied from the model to clone.
+ checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)
+
+ # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
+ # Keras models and `Estimator`s are exported with the same format.
+ # Every time a mode is exported, the code checks to see if new variables have
+ # been created (e.g. optimizer slot variables). If that is the case, the
+ # checkpoint is re-saved to include the new variables.
+ export_args = {'builder': builder,
+ 'model': model,
+ 'custom_objects': custom_objects,
+ 'checkpoint_path': checkpoint_path}
+
+ has_saved_vars = False
+ if model.optimizer:
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args)
+ has_saved_vars = True
+ _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args)
+ else:
+ logging.warning(
+ 'Model was compiled with an optimizer, but the optimizer is not from '
+ '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
+ 'graph was exported. The train and evaluate graphs were not added to '
+ 'the SavedModel.')
+ _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args)
+
+ builder.save(as_text)
+
+ gfile.Rename(temp_export_dir, export_dir)
+ return export_dir
- if not file_io.file_exists(assets_destination_dir):
- file_io.recursive_create_dir(assets_destination_dir)
+def _export_model_json_and_variables(model, saved_model_path):
+ """Save model variables and json structure into SavedModel subdirectories."""
+ # Save model configuration as a json string under assets folder.
+ model_json = model.to_json()
model_json_filepath = os.path.join(
- compat.as_bytes(assets_destination_dir),
- compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ saved_model_utils.get_or_create_assets_dir(saved_model_path),
+ compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
file_io.write_string_to_file(model_json_filepath, model_json)
- # save model weights in checkpoint format.
- checkpoint_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.VARIABLES_DIRECTORY))
+ # Save model weights in checkpoint format under variables folder.
+ saved_model_utils.get_or_create_variables_dir(saved_model_path)
+ checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
+ model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+ return checkpoint_prefix
- if not file_io.file_exists(checkpoint_destination_dir):
- file_io.recursive_create_dir(checkpoint_destination_dir)
- checkpoint_prefix = os.path.join(
- compat.as_text(checkpoint_destination_dir),
- compat.as_text(constants.VARIABLES_FILENAME))
- model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+def _get_var_list(model):
+ """Return list of all checkpointed saveable objects in the model."""
+ return checkpointable_utils.named_saveables(model)
+
+
+def _export_mode(
+ mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
+ """Export a model, and optionally save new vars from the clone model.
+
+ Args:
+ mode: A `tf.estimator.ModeKeys` string.
+ has_saved_vars: A `boolean` indicating whether the SavedModel has already
+ exported variables.
+ builder: A `SavedModelBuilder` object.
+ model: A `tf.keras.Model` object.
+ custom_objects: A dictionary mapping string names to custom classes
+ or functions.
+ checkpoint_path: String path to checkpoint.
+
+ Raises:
+ ValueError: If the train/eval mode is being exported, but the model does
+ not have an optimizer.
+ """
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
+ if compile_clone and not model.optimizer:
+ raise ValueError(
+ 'Model does not have an optimizer. Cannot export mode %s' % mode)
+
+ model_graph = ops.get_default_graph()
+ with ops.Graph().as_default() as g:
+
+ K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
+
+ # Clone the model into blank graph. This will create placeholders for inputs
+ # and targets.
+ clone = models_lib.clone_and_build_model(
+ model, custom_objects=custom_objects, compile_clone=compile_clone)
+
+ # Make sure that iterations variable is added to the global step collection,
+ # to ensure that, when the SavedModel graph is loaded, the iterations
+ # variable is returned by `tf.train.get_global_step()`. This is required for
+ # compatibility with the SavedModelEstimator.
+ if compile_clone:
+ g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
+
+ # Extract update and train ops from train/test/predict functions.
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ clone._make_train_function()
+ builder._add_train_op(clone.train_function.updates_op)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ clone._make_test_function()
+ else:
+ clone._make_predict_function()
+ g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
+
+ clone_var_list = checkpointable_utils.named_saveables(clone)
+
+ with session.Session().as_default():
+ if has_saved_vars:
+ # Confirm all variables in the clone have an entry in the checkpoint.
+ status = clone.load_weights(checkpoint_path)
+ status.assert_existing_objects_matched()
+ else:
+ # Confirm that variables between the clone and model match up exactly,
+ # not counting optimizer objects. Optimizer objects are ignored because
+ # if the model has not trained, the slot variables will not have been
+ # created yet.
+ # TODO(b/113179535): Replace with checkpointable equivalence.
+ _assert_same_non_optimizer_objects(model, model_graph, clone, g)
+
+ # TODO(b/113178242): Use value transfer for checkpointable objects.
+ clone.load_weights(checkpoint_path)
+
+ # Add graph and variables to SavedModel.
+ # TODO(b/113134168): Switch to add_meta_graph_and_variables.
+ clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
+ builder._has_saved_variables = True
+
+ # Add graph to the SavedModel builder.
+ builder.add_meta_graph(
+ model_fn_lib.EXPORT_TAG_MAP[mode],
+ signature_def_map=_create_signature_def_map(clone, mode),
+ saver=saver_lib.Saver(clone_var_list),
+ main_op=variables.local_variables_initializer())
+ return None
+
+
+def _create_signature_def_map(model, mode):
+ """Create a SignatureDef map from a Keras model."""
+ inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
+ if model.optimizer:
+ targets_dict = {x.name.split(':')[0]: x
+ for x in model.targets if x is not None}
+ inputs_dict.update(targets_dict)
+ outputs_dict = {name: x
+ for name, x in zip(model.output_names, model.outputs)}
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode,
+ predictions=outputs_dict,
+ loss=model.total_loss if model.optimizer else None,
+ metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model))
+ return export_helpers.build_all_signature_defs(
+ inputs_dict,
+ export_outputs=export_outputs,
+ serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))
+
+
+def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):
+ """Assert model and clone contain the same checkpointable objects."""
+
+ def get_non_optimizer_objects(m, g):
+ """Gather set of model and optimizer checkpointable objects."""
+ # Set default graph because optimizer.variables() returns optimizer
+ # variables defined in the default graph.
+ with g.as_default():
+ all_objects = set(checkpointable_utils.list_objects(m))
+ optimizer_and_variables = set()
+ for obj in all_objects:
+ if isinstance(obj, optimizers.TFOptimizer):
+ optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
+ optimizer_and_variables.update(set(obj.optimizer.variables()))
+ return all_objects - optimizer_and_variables
+
+ model_objects = get_non_optimizer_objects(model, model_graph)
+ clone_objects = get_non_optimizer_objects(clone, clone_graph)
+
+ if len(model_objects) != len(clone_objects):
+ raise errors.InternalError(
+ None, None,
+ 'Model and clone must use the same variables.'
+ '\n\tModel variables: %s\n\t Clone variables: %s'
+ % (model_objects, clone_objects))
-def load_model(saved_model_path):
+def load_keras_model(saved_model_path):
"""Load a keras.Model from SavedModel.
load_model reinstantiates model state by:
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 107ae1b07b..8a0dbef788 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -20,18 +20,35 @@ from __future__ import print_function
import os
import shutil
+
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as training_module
class TestModelSavingandLoading(test.TestCase):
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
def test_saving_sequential_model(self):
with self.test_session():
model = keras.models.Sequential()
@@ -48,13 +65,11 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -69,12 +84,9 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
-
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -95,12 +107,10 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -118,12 +128,10 @@ class TestModelSavingandLoading(test.TestCase):
y = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -142,14 +150,13 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
+ model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
loaded_model.compile(
loss='mse',
optimizer=training_module.RMSPropOptimizer(0.1),
@@ -170,8 +177,10 @@ class TestModelSavingandLoading(test.TestCase):
self.assertAllClose(ref_y, y, atol=1e-05)
# test saving/loading again
- keras_saved_model.save_model(loaded_model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model2 = self._save_model_dir('saved_model_2')
+ output_path2 = keras_saved_model.save_keras_model(
+ loaded_model, temp_saved_model2)
+ loaded_model = keras_saved_model.load_keras_model(output_path2)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -190,11 +199,231 @@ class TestModelSavingandLoading(test.TestCase):
return self.layer2(self.layer1(inp))
model = SubclassedModel()
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
+
+ temp_saved_model = self._save_model_dir()
with self.assertRaises(NotImplementedError):
- keras_saved_model.save_model(model, temp_saved_model)
+ keras_saved_model.save_keras_model(model, temp_saved_model)
+
+
+class LayerWithLearningPhase(keras.engine.base_layer.Layer):
+
+ def call(self, x):
+ phase = keras.backend.learning_phase()
+ output = tf_utils.smart_cond(
+ phase, lambda: x * 0, lambda: array_ops.identity(x))
+ if not context.executing_eagerly():
+ output._uses_learning_phase = True # pylint: disable=protected-access
+ return output
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+def functional_model(uses_learning_phase):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ if uses_learning_phase:
+ x = LayerWithLearningPhase()(x)
+ return keras.models.Model(inputs, x)
+
+
+def sequential_model(uses_learning_phase):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ if uses_learning_phase:
+ model.add(LayerWithLearningPhase())
+ return model
+
+
+def load_model(sess, path, mode):
+ tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+ sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ if mode == model_fn_lib.ModeKeys.PREDICT else mode)
+ meta_graph_def = loader_impl.load(sess, tags, path)
+ inputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].inputs.items()}
+ outputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
+ return inputs, outputs
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
+
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
+ @parameterized.parameters(
+ (functional_model, True, training_module.AdadeltaOptimizer(), True),
+ (functional_model, True, training_module.AdadeltaOptimizer(), False),
+ (functional_model, False, None, False),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), True),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), False),
+ (sequential_model, False, None, False))
+ def testSaveAndLoadSavedModelExport(
+ self, model_builder, uses_learning_phase, optimizer, train_before_export):
+ saved_model_path = self._save_model_dir()
+ with self.test_session(graph=ops.Graph()):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model = model_builder(uses_learning_phase)
+ if optimizer is not None:
+ model.compile(
+ loss='mse',
+ optimizer=optimizer,
+ metrics=['mae'])
+ if train_before_export:
+ model.train_on_batch(input_arr, target_arr)
+
+ ref_loss, ref_mae = model.evaluate(input_arr, target_arr)
+
+ ref_predict = model.predict(input_arr)
+
+ # Export SavedModel
+ output_path = keras_saved_model.save_keras_model(model, saved_model_path)
+
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ target_name = output_name + '_target'
+
+ # Load predict graph, and test predictions
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+
+ predictions = sess.run(outputs[output_name],
+ {inputs[input_name]: input_arr})
+ self.assertAllClose(ref_predict, predictions, atol=1e-05)
+
+ if optimizer:
+ # Load eval graph, and test predictions, loss and metric values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.EVAL)
+
+ eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
+ self.assertAllClose(
+ ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05)
+ self.assertAllClose(
+ ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
+
+ # Load train graph, and check for the train op, and prediction values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.TRAIN)
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertIn('loss', outputs)
+ self.assertIn('metrics/mae/update_op', outputs)
+ self.assertIn('metrics/mae/value', outputs)
+ self.assertIn('predictions/' + output_name, outputs)
+
+ # Train for a step
+ train_op = ops.get_collection(constants.TRAIN_OP_KEY)
+ train_outputs, _ = sess.run(
+ [outputs, train_op], {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+ self.assertEqual(int(train_before_export) + 1,
+ sess.run(training_module.get_global_step()))
+
+ if uses_learning_phase:
+ self.assertAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+ else:
+ self.assertNotAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+
+ def testSaveAndLoadSavedModelWithCustomObject(self):
+ saved_model_path = self._save_model_dir()
+ with session.Session(graph=ops.Graph()) as sess:
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+ inputs = keras.layers.Input(shape=(1,))
+ outputs = keras.layers.Activation(relu6)(inputs)
+ model = keras.models.Model(inputs, outputs)
+ output_path = keras_saved_model.save_keras_model(
+ model, saved_model_path, custom_objects={'relu6': relu6})
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ predictions = sess.run(
+ outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]})
+ self.assertAllEqual([[6], [0], [4]], predictions)
+
+ def testAssertModelCloneSameObjectsIgnoreOptimizer(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
+
+ def testAssertModelCloneSameObjectsThrowError(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(4)(x)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ with self.assertRaisesRegexp(
+ errors.InternalError, 'Model and clone must use the same variables.'):
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index d877831fce..a6ce45c203 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -416,12 +416,17 @@ class Image(ItemHandler):
def decode_image():
"""Decodes a image based on the headers."""
- return image_ops.decode_image(image_buffer, channels=self._channels)
+ return math_ops.cast(
+ image_ops.decode_image(image_buffer, channels=self._channels),
+ self._dtype)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
- return image_ops.decode_jpeg(
- image_buffer, channels=self._channels, dct_method=self._dct_method)
+ return math_ops.cast(
+ image_ops.decode_jpeg(
+ image_buffer,
+ channels=self._channels,
+ dct_method=self._dct_method), self._dtype)
def check_jpeg():
"""Checks if an image is jpeg."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index d783d4fef4..826242c9d7 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -37,12 +37,12 @@ from tensorflow.python.platform import test
class TFExampleDecoderTest(test.TestCase):
def _EncodedFloatFeature(self, ndarray):
- return feature_pb2.Feature(float_list=feature_pb2.FloatList(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=ndarray.flatten().tolist()))
def _EncodedInt64Feature(self, ndarray):
- return feature_pb2.Feature(int64_list=feature_pb2.Int64List(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
with self.test_session():
@@ -74,12 +74,14 @@ class TFExampleDecoderTest(test.TestCase):
if image_format in ['raw', 'RAW']:
return constant_op.constant(image.tostring(), dtype=dtypes.string)
- def GenerateImage(self, image_format, image_shape):
+ def GenerateImage(self, image_format, image_shape, image_dtype=np.uint8):
"""Generates an image and an example containing the encoded image.
Args:
image_format: the encoding format of the image.
image_shape: the shape of the image to generate.
+ image_dtype: the dtype of values in the image. Only 'raw' image can have
+ type different than uint8.
Returns:
image: the generated image.
@@ -87,14 +89,18 @@ class TFExampleDecoderTest(test.TestCase):
serialized image and a feature key 'image/format' set to the image
encoding format ['jpeg', 'JPEG', 'png', 'PNG', 'raw'].
"""
+ assert image_format in ['raw', 'RAW'] or image_dtype == np.uint8
num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
image = np.linspace(
- 0, num_pixels - 1, num=num_pixels).reshape(image_shape).astype(np.uint8)
+ 0, num_pixels - 1,
+ num=num_pixels).reshape(image_shape).astype(image_dtype)
tf_encoded = self._Encoder(image, image_format)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': self._EncodedBytesFeature(tf_encoded),
- 'image/format': self._StringFeature(image_format)
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded': self._EncodedBytesFeature(tf_encoded),
+ 'image/format': self._StringFeature(image_format)
+ }))
return image, example.SerializeToString()
@@ -168,8 +174,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_decoded_image = self.DecodeExample(
serialized_example,
- tfexample_decoder.Image(
- shape=None, channels=channels),
+ tfexample_decoder.Image(shape=None, channels=channels),
image_format='jpeg')
self.assertEqual(tf_decoded_image.get_shape().ndims, 3)
@@ -225,27 +230,38 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
+ def testDecodeExampleWithRawEncodingFloatDtype(self):
image_shape = (2, 3, 3)
- unused_image, serialized_example = self.GenerateImage(
+ image, serialized_example = self.GenerateImage(
+ image_format='raw', image_shape=image_shape, image_dtype=np.float32)
+
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(shape=image_shape, dtype=dtypes.float32),
+ image_format='raw')
+
+ self.assertAllClose(image, decoded_image, atol=0)
+
+ def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
+ image_shape = (2, 3, 3)
+ # Image has type uint8 but decoding at uint16 should not cause problems.
+ image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
- # decode_raw support uint16 now so ValueError will be thrown instead.
- with self.assertRaisesRegexp(
- ValueError,
- 'true_fn and false_fn must have the same type: uint16, uint8'):
- unused_decoded_image = self.RunDecodeExample(
- serialized_example,
- tfexample_decoder.Image(dtype=dtypes.uint16),
- image_format='jpeg')
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(dtype=dtypes.uint16),
+ image_format='jpeg')
+ self.assertAllClose(image, decoded_image, atol=1.001)
def testDecodeExampleWithStringTensor(self):
tensor_shape = (2, 3, 1)
np_array = np.array([[['ab'], ['cd'], ['ef']],
[['ghi'], ['jkl'], ['mnop']]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._BytesFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._BytesFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -259,7 +275,9 @@ class TFExampleDecoderTest(test.TestCase):
default_value=constant_op.constant(
'', shape=tensor_shape, dtype=dtypes.string))
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -271,9 +289,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFloatTensor(self):
np_array = np.random.rand(2, 3, 1).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -282,7 +301,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -291,9 +312,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithInt64Tensor(self):
np_array = np.random.randint(1, 10, size=(2, 3, 1))
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -302,7 +324,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -311,9 +335,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensor(self):
np_array = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -322,7 +347,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -332,9 +359,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFixLenTensorWithShape(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -342,12 +370,10 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
- parsing_ops.FixedLenFeature(
- np_array.shape, dtype=dtypes.int64),
+ parsing_ops.FixedLenFeature(np_array.shape, dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -357,9 +383,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensorToDense(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -369,8 +396,7 @@ class TFExampleDecoderTest(test.TestCase):
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -382,12 +408,18 @@ class TFExampleDecoderTest(test.TestCase):
np_image = np.random.rand(2, 3, 1).astype('f')
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/shape': self._EncodedInt64Feature(np.array(np_labels.shape)),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/shape':
+ self._EncodedInt64Feature(np.array(np_labels.shape)),
+ }))
serialized_example = example.SerializeToString()
@@ -401,11 +433,9 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
- tfexample_decoder.Tensor(
- 'labels', shape_keys='labels/shape'),
+ tfexample_decoder.Tensor('labels', shape_keys='labels/shape'),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -419,14 +449,22 @@ class TFExampleDecoderTest(test.TestCase):
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
height, width, depth = np_labels.shape
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/height': self._EncodedInt64Feature(np.array([height])),
- 'labels/width': self._EncodedInt64Feature(np.array([width])),
- 'labels/depth': self._EncodedInt64Feature(np.array([depth])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/height':
+ self._EncodedInt64Feature(np.array([height])),
+ 'labels/width':
+ self._EncodedInt64Feature(np.array([width])),
+ 'labels/depth':
+ self._EncodedInt64Feature(np.array([depth])),
+ }))
serialized_example = example.SerializeToString()
@@ -442,8 +480,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
tfexample_decoder.Tensor(
'labels',
@@ -459,10 +496,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithSparseTensor(self):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -472,7 +511,9 @@ class TFExampleDecoderTest(test.TestCase):
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'values': parsing_ops.VarLenFeature(dtype=dtypes.float32),
}
- items_to_handlers = {'labels': tfexample_decoder.SparseTensor(),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.SparseTensor(),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -485,11 +526,13 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- 'shape': self._EncodedInt64Feature(np_shape),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ 'shape': self._EncodedInt64Feature(np_shape),
+ }))
serialized_example = example.SerializeToString()
@@ -515,10 +558,12 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -544,10 +589,12 @@ class TFExampleDecoderTest(test.TestCase):
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
np_dense = np.array([0.0, 0.1, 0.2, 0.0, 0.0, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -559,8 +606,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'labels':
- tfexample_decoder.SparseTensor(
- shape=np_shape, densify=True),
+ tfexample_decoder.SparseTensor(shape=np_shape, densify=True),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -572,9 +618,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -603,9 +650,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -701,12 +749,14 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -740,26 +790,32 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
- 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
}
items_to_handlers = {
@@ -784,11 +840,16 @@ class TFExampleDecoderTest(test.TestCase):
with self.test_session():
tf_string = tf_encoded.eval()
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
- value=[tf_string, tf_string])),
- 'image/format': self._StringFeature(image_format),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded':
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[tf_string, tf_string])),
+ 'image/format':
+ self._StringFeature(image_format),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -797,8 +858,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features={
'image/encoded':
- parsing_ops.FixedLenFeature(
- (2,), dtypes.string),
+ parsing_ops.FixedLenFeature((2,), dtypes.string),
'image/format':
parsing_ops.FixedLenFeature(
(), dtypes.string, default_value=image_format),
@@ -814,10 +874,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithLookup(self):
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/class/text': self._BytesFeature(
- np.array(['cat', 'dog', 'guinea pig'])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ }))
serialized_example = example.SerializeToString()
# 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
table = lookup_ops.index_table_from_tensor(
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 56e451e2e3..298ffc1ded 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -16,6 +16,7 @@ package(
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
"//learning/deepmind:__subpackages__",
+ "//medical/pathology:__subpackages__",
"//tensorflow:__subpackages__",
],
)
@@ -166,6 +167,7 @@ py_library(
name = "keras_support",
srcs = [
"python/tpu/keras_support.py",
+ "python/tpu/keras_tpu_variables.py",
],
srcs_version = "PY2AND3",
visibility = [
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 06553929dc..9ee5ecb123 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -21,9 +21,9 @@ namespace tensorflow {
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
+ .Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {bfloat16, float}")
- .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
@@ -31,15 +31,17 @@ instance supplies its own input. If group_assignment is empty, the output of
each is the sum of all the inputs, otherwise the output of each is the sum of
the inputs belonging to the same group.
-For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
-group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
-Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
+For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
+Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
+and `B, D, F, H` as group 1. Thus we get the outputs:
+`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
input: The local input to the sum.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
-group_assignment: The list of group ids. `group_assignment[i]` represents the
- group id of replica i.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 8e6e9aa0cd..b498599962 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,7 +156,8 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)));
+ stub->NewSession(&context, new_session_request, &new_session_response)))
+ << new_session_response.error_message();
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index bf442d9116..3ed571aff9 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -21,8 +21,10 @@ from __future__ import print_function
import platform
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
@@ -36,10 +38,35 @@ if platform.system() != "Windows":
_tpu_ops = loader.load_op_library(
resource_loader.get_path_to_datafile("_tpu_ops.so"))
+ def cross_replica_sum(x, group_assignment=None, name=None):
+ """Sum the input tensor accorss replicas according to group_assignment.
+
+ Args:
+ x: The local tensor to the sum.
+ group_assignment: Optional 2d int32 lists with shape [num_groups,
+ num_replicas_per_group]. `group_assignment[i]` represents the replica
+ ids in the ith subgroup.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is summed across replicas.
+ """
+ if group_assignment is None:
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ if num_shards is None:
+ logging.warning(
+ "cross_replica_sum should be used within a tpu_shard_context, but "
+ "got unset number_of_shards. Assuming 1.")
+ num_shards = 1
+ group_assignment = [list(range(num_shards))]
+
+ return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
- return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
+ # The graident with respect to group_assignment is None.
+ return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 87b900574c..ff88508d03 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -58,29 +58,38 @@ from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_reso
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
+from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import training_arrays
+from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.layers import embeddings
+from tensorflow.python.keras.utils.generic_utils import make_batches
+from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import tf_inspect
_SESSIONS = {}
@@ -96,9 +105,9 @@ def tpu_session(cluster_resolver):
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+ logging.info('Connecting to: %s', master)
graph = ops.Graph()
session = tf_session.Session(graph=graph, target=master, config=config)
-
with graph.as_default():
session.run(tpu.initialize_system())
@@ -109,32 +118,64 @@ def tpu_session(cluster_resolver):
def reset_tpu_sessions():
_SESSIONS.clear()
+try:
+ from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ issparse = None
-# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name
- """Construct a TPUDistributionStrategy."""
- from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
- # We are including this for (a) backwards compatibility for open sourced
- # releases of TensorFlow and (b) to work around a circular dependency
- # where keras_support and tpu_strategy depends on each other. Once we release
- # a final version and remove support for the old API, this will be deleted.
- # (See bug above for more details)
- if tpu_cluster_resolver is None:
- tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
-
- args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
- if len(args) == 4:
- logging.info('Detected new TPUStrategy API.')
- return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
- steps_per_run=1,
- num_cores=num_cores)
- else:
- logging.info('Detected old TPUStrategy API.')
- strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
- strategy._tpu_cluster_resolver = tpu_cluster_resolver
- return strategy
+def get_tpu_system_metadata(tpu_cluster_resolver):
+ """Retrieves TPU system metadata given a TPUClusterResolver."""
+ master = tpu_cluster_resolver.master()
+
+ # pylint: disable=protected-access
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
+ tpu_system_metadata = (
+ tpu_system_metadata_lib._query_tpu_system_metadata(
+ master,
+ cluster_def=cluster_def,
+ query_topology=False))
+
+ return tpu_system_metadata
+
+
+class TPUDistributionStrategy(object):
+ """The strategy to run Keras model on TPU."""
+
+ def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
+ """Construct a TPUDistributionStrategy.
+
+ Args:
+ tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will
+ create one with '' as master address.
+ using_single_core: Bool. This is the debugging option, which might be
+ removed in future once the model replication functionality is mature
+ enough. If `False` (default behavior), the system automatically finds
+ the best configuration, in terms of number of TPU cores, for the model
+ replication, typically using all avaiable TPU cores. If overwrites as
+ `True`, force the model replication using single core, i.e., no
+ replication.
+ """
+
+ if tpu_cluster_resolver is None:
+ tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+ metadata = get_tpu_system_metadata(tpu_cluster_resolver)
+ self._tpu_metadata = metadata
+ self._tpu_cluster_resolver = tpu_cluster_resolver
+ self._num_cores = 1 if using_single_core else metadata.num_cores
+
+ # Walk device list to identify TPU worker for enqueue/dequeue operations.
+ worker_re = re.compile('/job:([^/]+)')
+ for device in metadata.devices:
+ if 'TPU:0' in device.name:
+ self.worker_name = worker_re.search(device.name).group(1)
+ break
+
+ @property
+ def num_towers(self):
+ return self._num_cores
class TPUEmbedding(embeddings.Embedding):
@@ -493,7 +534,7 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
shard_infeed_tensors = []
for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
infeed_tensors = []
with ops.device('/device:TPU:%d' % shard_id):
for spec in input_specs:
@@ -638,7 +679,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
assert len(shard_infeed_tensors) == self._strategy.num_towers
infeed_ops = []
for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
infeed_ops.append(
tpu_ops.infeed_enqueue_tuple(
shard_infeed_tensors[shard_id],
@@ -716,8 +757,7 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
- # TODO(power): Replicate variables.
- with ops.device('/device:TPU:0'):
+ with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
@@ -796,7 +836,7 @@ class TPUFunction(object):
# Build output ops.
outfeed_op = []
for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -814,7 +854,7 @@ class TPUFunction(object):
def _test_model_compiles(self, tpu_model_ops):
"""Verifies that the given TPUModelOp can be compiled via XLA."""
logging.info('Started compiling')
- start_time = time.clock()
+ start_time = time.time()
result = K.get_session().run(tpu_model_ops.compile_op)
proto = tpu_compilation_result.CompilationResultProto()
@@ -823,38 +863,52 @@ class TPUFunction(object):
raise RuntimeError('Compilation failed: {}'.format(
proto.status_error_message))
- end_time = time.clock()
+ end_time = time.time()
logging.info('Finished compiling. Time elapsed: %s secs',
end_time - start_time)
- def __call__(self, inputs):
- assert isinstance(inputs, list)
+ def _lookup_infeed_manager(self, inputs):
+ """Return an existing manager, or construct a new InfeedManager for inputs.
+
+ _lookup_infeed_manager will return an existing InfeedManager if one has been
+ previously assigned for this model and input. If not, it will construct a
+ new TPUNumpyInfeedManager.
+
+ Args:
+ inputs: A NumPy input to the model.
+
+ Returns:
+ A `TPUInfeedManager` object to manage infeeds for this input.
+ """
+ if inputs is None:
+ return None
- infeed_manager = None
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
- infeed_manager = mgr
- break
- if infeed_manager is None:
- infeed_manager = TPUNumpyInfeedManager(self.model._strategy)
+ return mgr
+ return TPUNumpyInfeedManager(self.model._strategy)
- # Strip sample weight from inputs
- if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
- self.execution_mode == model_fn_lib.ModeKeys.EVAL):
- input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- else:
- input_tensors = self.model._feed_inputs
+ def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
+ """Looks up the corresponding `TPUModelOp` for a given `input_specs`.
- infeed_instance = infeed_manager.make_infeed_instance(inputs)
- del inputs # To avoid accident usage.
- input_specs = infeed_instance.make_input_specs(input_tensors)
+ It instantiates a new copy of the model for each unique input shape.
+
+ Args:
+ input_specs: The specification of the inputs to train on.
+ infeed_manager: The infeed manager responsible for feeding in data.
+
+ Returns:
+ A `TPUModelOp` instance that can be used to execute a step of the model.
+ """
+ if input_specs is None or infeed_manager is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None
# XLA requires every operation in the graph has a fixed shape. To
# handle varying batch sizes we recompile a new sub-graph for each
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
-
if shape_key not in self._compilation_cache:
with self.model.tpu_session():
logging.info('New input shapes; (re-)compiling: mode=%s, %s',
@@ -864,19 +918,42 @@ class TPUFunction(object):
self._compilation_cache[shape_key] = new_tpu_model_ops
self._test_model_compiles(new_tpu_model_ops)
- # Initialize our TPU weights on the first compile.
- self.model._initialize_weights(self._cloned_model)
- tpu_model_ops = self._compilation_cache[shape_key]
+ return self._compilation_cache[shape_key]
- infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+ def _construct_input_tensors_and_inputs(self, inputs):
+ """Returns input tensors and numpy array inputs corresponding to `inputs`.
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ Args:
+ inputs: NumPy inputs.
+
+ Returns:
+ A tuple of `input_tensors`, and `inputs`.
+ """
+ if inputs is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None, None
+ # Strip sample weight from inputs
+ if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
+ self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ input_tensors = self.model._feed_inputs + self.model._feed_targets
+ inputs = inputs[:len(input_tensors)]
+ return input_tensors, inputs
+ else:
+ input_tensors = self.model._feed_inputs
+ return input_tensors, inputs
+
+ def _process_outputs(self, outfeed_outputs):
+ """Processes the outputs of a model function execution.
+
+ Args:
+ outfeed_outputs: The sharded outputs of the TPU computation.
- # TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
+ Returns:
+ The aggregated outputs of the TPU computation to be used in the rest of
+ the model execution.
+ """
+ # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
outputs = [[]] * len(self._outfeed_spec)
outputs_per_replica = len(self._outfeed_spec)
@@ -889,7 +966,139 @@ class TPUFunction(object):
return [np.concatenate(group) for group in outputs]
else:
- return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers]
+ return outfeed_outputs[:len(outfeed_outputs) //
+ self._strategy.num_towers]
+
+ def __call__(self, inputs):
+ """__call__ executes the function on the computational hardware.
+
+ It handles executing infeed, and preprocessing in addition to executing the
+ model on the TPU hardware.
+
+ Note: `__call__` has a sibling method `pipeline_run` which performs the same
+ operations, but with software pipelining.
+
+ Args:
+ inputs: The inputs to use to train.
+
+ Returns:
+ The output of the computation for the given mode it is executed in.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ assert isinstance(inputs, list)
+
+ infeed_manager = self._lookup_infeed_manager(inputs)
+ input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs)
+ infeed_instance = infeed_manager.make_infeed_instance(inputs)
+ del inputs # To avoid accident usage.
+ input_specs = infeed_instance.make_input_specs(input_tensors)
+ tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
+ infeed_manager)
+ infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+
+ def pipeline_run(self, cur_step_inputs, next_step_inputs):
+ """pipeline_run executes the function on the computational hardware.
+
+ pipeline_run performs the same computation as __call__, however it runs the
+ infeed in a software pipelined fashion compared to the on-device execution.
+
+ Note: it is the responsibility of the caller to call `pipeline_run` in the
+ following sequence:
+ - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)`
+ - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s
+ - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None`
+ Additionally, it is the responsibility of the caller to pass
+ `next_step_inputs` as `cur_step_inputs` on the next invocation of
+ `pipeline_run`.
+
+ Args:
+ cur_step_inputs: The current step's inputs.
+ next_step_inputs: The next step's inputs.
+
+ Returns:
+ The output of the computation for the given mode it is executed in.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ # Software pipelined case.
+ next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
+ cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
+
+ if (next_step_infeed_manager is not None
+ and cur_step_infeed_manager is not None):
+ assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
+
+ next_input_tensors, next_step_inputs = (
+ self._construct_input_tensors_and_inputs(next_step_inputs))
+ cur_input_tensors, cur_step_inputs = (
+ self._construct_input_tensors_and_inputs(cur_step_inputs))
+
+ cur_infeed_instance = None
+ if cur_step_infeed_manager:
+ cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance(
+ cur_step_inputs)
+ next_infeed_instance = None
+ if next_step_infeed_manager:
+ next_infeed_instance = next_step_infeed_manager.make_infeed_instance(
+ next_step_inputs)
+
+ del cur_step_inputs # Avoid accidental re-use.
+ del next_step_inputs # Avoid accidental re-use.
+
+ cur_tpu_model_ops = None
+ next_tpu_model_ops = None
+ infeed_dict = None
+
+ if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
+ cur_input_specs = cur_infeed_instance.make_input_specs(
+ cur_input_tensors)
+ cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ cur_input_specs, cur_step_infeed_manager)
+
+ if (next_infeed_instance
+ and next_input_tensors
+ and next_step_infeed_manager):
+ next_input_specs = next_infeed_instance.make_input_specs(
+ next_input_tensors)
+ next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ next_input_specs, next_step_infeed_manager)
+ infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ if next_tpu_model_ops and cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+ if cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, outfeed_outputs = session.run([
+ cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ return self._process_outputs(outfeed_outputs)
+ if next_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ return None
+ raise RuntimeError('Internal error: both current & next tpu_model_ops '
+ 'were None')
+
class KerasTPUModel(models.Model):
@@ -919,7 +1128,6 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = False
self._session = tpu_session(cluster_resolver)
- self._graph = self._session.graph
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
@@ -982,6 +1190,10 @@ class KerasTPUModel(models.Model):
steps_per_epoch=None,
validation_steps=None,
**kwargs):
+ if context.executing_eagerly():
+ raise EnvironmentError('KerasTPUModel currently does not support eager '
+ 'mode.')
+
assert not self._numpy_to_infeed_manager_list # Ensure empty.
infeed_managers = [] # Managers to clean up at the end of the fit call.
@@ -994,7 +1206,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
+ with self.tpu_session() as sess,\
+ ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1033,7 +1246,28 @@ class KerasTPUModel(models.Model):
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).fit(
+ if not kwargs.get('_pipeline', True):
+ logging.info(
+ 'Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs)
+ return self._pipeline_fit(
x,
y,
batch_size,
@@ -1052,6 +1286,457 @@ class KerasTPUModel(models.Model):
finally:
self._numpy_to_infeed_manager_list = []
+ def evaluate(self,
+ x=None,
+ y=None,
+ batch_size=None,
+ verbose=1,
+ sample_weight=None,
+ steps=None):
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with self.tpu_session() as sess:
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(
+ x,
+ y,
+ batch_size,
+ verbose,
+ sample_weight,
+ steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
+
+ def _pipeline_fit(self,
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs):
+ # Similar to super.fit(...), but modified to support software pipelining.
+
+ # Backwards compatibility
+ if batch_size is None and steps_per_epoch is None:
+ batch_size = 32
+ # Legacy support
+ if 'nb_epoch' in kwargs:
+ logging.warning('The `nb_epoch` argument in `fit` has been renamed '
+ '`epochs`.')
+ epochs = kwargs.pop('nb_epoch')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ # Validate and standardize user data
+ x, y, sample_weights = self._standardize_user_data(
+ x,
+ y,
+ sample_weight=sample_weight,
+ class_weight=class_weight,
+ batch_size=batch_size,
+ check_steps=True,
+ steps_name='steps_per_epoch',
+ steps=steps_per_epoch,
+ validation_split=validation_split)
+
+ # Prepare validation data
+ val_x, val_y, val_sample_weights = self._prepare_validation_data(
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size)
+ self._pipeline_fit_loop(
+ x,
+ y,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_inputs=val_x,
+ val_targets=val_y,
+ val_sample_weights=val_sample_weights,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
+
+ def _pipeline_fit_loop(self,
+ inputs,
+ targets,
+ sample_weights,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ shuffle,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps):
+ self._make_train_function()
+ sample_weights = sample_weights or []
+ val_sample_weights = val_sample_weights or []
+ if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = inputs + targets + sample_weights + [1]
+ else:
+ ins = inputs + targets + sample_weights
+
+ do_validation = False
+ if val_inputs:
+ do_validation = True
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
+ print('Train on %d samples, validate on %d samples' %
+ (inputs[0].shape[0], val_inputs[0].shape[0]))
+
+ if validation_steps:
+ do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` when doing step-wise '
+ 'training, i.e. `steps_per_epoch` must be set.')
+
+ num_training_samples = training_utils.check_num_samples(
+ ins, batch_size, steps_per_epoch, 'steps_per_epoch')
+ count_mode = 'steps' if steps_per_epoch else 'samples'
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ self,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ samples=num_training_samples,
+ validation_steps=validation_steps,
+ verbose=verbose,
+ count_mode=count_mode)
+
+ if num_training_samples is not None:
+ index_array = np.arange(num_training_samples)
+
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
+ callbacks.on_train_begin()
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in self.stateful_metric_functions:
+ m.reset_states()
+ # Update callbacks
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ if steps_per_epoch is not None:
+ # Step-wise fit loop.
+ self._pipeline_fit_loop_step_wise(
+ ins=ins,
+ callbacks=callbacks,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+ else:
+ # Sample-wise fit loop.
+ self._pipeline_fit_loop_sample_wise(
+ ins=ins,
+ callbacks=callbacks,
+ index_array=index_array,
+ shuffle=shuffle,
+ batch_size=batch_size,
+ num_training_samples=num_training_samples,
+ indices_for_conversion_to_dense=indices_for_conversion_to_dense,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+ return self.history
+
+ def _pipeline_fit_loop_sample_wise(self,
+ ins,
+ callbacks,
+ index_array,
+ shuffle,
+ batch_size,
+ num_training_samples,
+ indices_for_conversion_to_dense,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+ if shuffle == 'batch':
+ index_array = training_utils.batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+ batches = make_batches(num_training_samples, batch_size)
+
+ ins_last_batch = None
+ last_batch_logs = None
+ batch_index = 0
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ try:
+ if isinstance(ins[-1], int):
+ # Do not slice the training phase flag.
+ ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = slice_arrays(ins, batch_ids)
+ except TypeError:
+ raise TypeError('TypeError while preparing batch. If using HDF5 '
+ 'input data, pass shuffle="batch".')
+
+ # Pipeline batch logs
+ next_batch_logs = {}
+ next_batch_logs['batch'] = batch_index
+ next_batch_logs['size'] = len(batch_ids)
+ if batch_index > 0:
+ # Callbacks operate one step behind in software pipeline.
+ callbacks.on_batch_begin(batch_index - 1, last_batch_logs)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
+ next_step_inputs=ins_batch)
+ ins_last_batch = ins_batch
+
+ if batch_index == 0:
+ assert outs is None
+ else:
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o # pylint: disable=unsupported-assignment-operation
+ callbacks.on_batch_end(batch_index - 1, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+ last_batch_logs = next_batch_logs
+
+ # Final batch
+ callbacks.on_batch_begin(batch_index, last_batch_logs)
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None)
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o
+ callbacks.on_batch_end(batch_index, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _pipeline_fit_loop_step_wise(self,
+ ins,
+ callbacks,
+ steps_per_epoch,
+ epochs,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+
+ # Loop prologue
+ try:
+ outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins)
+ assert outs is None # Function shouldn't return anything!
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data on the first step '
+ 'of the epoch, preventing further training. Check to '
+ 'make sure your paths are correct and you have '
+ 'permissions to read the files. Skipping validation')
+
+ for step_index in range(steps_per_epoch - 1):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ if step_index < steps_per_epoch - 1:
+ next_step_inputs = ins
+ else:
+ next_step_inputs = None
+ outs = f.pipeline_run(cur_step_inputs=ins,
+ next_step_inputs=next_step_inputs)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your '
+ 'dataset can generate at least `steps_per_batch * '
+ 'epochs` batches (in this case, %d batches). You '
+ 'may need to use the repeat() function when '
+ 'building your dataset.' % steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ batch_logs[l] = o
+
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _prepare_validation_data(self,
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size):
+ """Prepares the validation dataset.
+
+ Args:
+ validation_data: The validation data (if provided)
+ validation_split: The validation split (if provided)
+ validation_steps: The validation steps (if provided)
+ x: The main training data x (if provided)
+ y: The main training data y (if provided)
+ sample_weights: The sample weights (if provided)
+ batch_size: The training batch size (if provided)
+
+ Returns:
+ A 3-tuple of (val_x, val_y, val_sample_weights).
+
+ Raises:
+ ValueError: If the provided arguments are not compatible with
+ `KerasTPUModel`.
+ """
+ # Note: this is similar to a section of $tf/python/keras/engine/training.py
+ # It differns in that tf.data objects are not allowed to be passed directly.
+ # Additionally, it handles validating shapes & types appropriately for use
+ # in TPUs.
+ if validation_data:
+ if (isinstance(validation_data, iterator_ops.Iterator) or
+ isinstance(validation_data, iterator_ops.EagerIterator) or
+ isinstance(validation_data, dataset_ops.Dataset)):
+ raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
+ 'for validation_data. Please instead pass a function '
+ 'that returns a `tf.data.Dataset`.')
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ else:
+ raise ValueError('When passing a `validation_data` argument, it must '
+ 'contain either 2 items (x_val, y_val), or 3 items '
+ '(x_val, y_val, val_sample_weights). However we '
+ 'received `validation_data=%s`' % validation_data)
+ val_x, val_y, val_sample_weights = self._standardize_user_data(
+ val_x,
+ val_y,
+ sample_weight=val_sample_weight,
+ batch_size=batch_size,
+ steps=validation_steps)
+ elif validation_split and 0. < validation_split < 1.:
+ if training_utils.has_symbolic_tensors(x):
+ raise ValueError('If your data is in the form of symbolic tensors, you '
+ 'cannot use `validation_split`.')
+ if hasattr(x[0], 'shape'):
+ split_at = int(x[0].shape[0] * (1. - validation_split))
+ else:
+ split_at = int(len(x[0]) * (1. - validation_split))
+
+ x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
+ y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
+ sample_weights, val_sample_weights = (slice_arrays(
+ sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ elif validation_steps:
+ val_x = []
+ val_y = []
+ val_sample_weights = []
+ else:
+ val_x = None
+ val_y = None
+ val_sample_weights = None
+
+ return val_x, val_y, val_sample_weights
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
@@ -1122,7 +1807,7 @@ class KerasTPUModel(models.Model):
@contextlib.contextmanager
def tpu_session(self):
"""Yields a TPU session and sets it as the default Keras session."""
- with self._graph.as_default():
+ with self._session.graph.as_default():
default_session = K.get_session()
# N.B. We have to call `K.set_session()` AND set our session as the
# TF default. `K.get_session()` surprisingly does not return the value
@@ -1212,5 +1897,10 @@ def tpu_model(model, strategy=None):
if strategy is None:
strategy = TPUDistributionStrategy()
+ else:
+ if not isinstance(strategy, TPUDistributionStrategy):
+ raise TypeError(
+ '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
+ 'Got: {}'.format(type(strategy)))
return KerasTPUModel(cpu_model=model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
new file mode 100644
index 0000000000..a423aeace7
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Distributed variable implementation for TPUs.
+
+N.B. This is an experimental feature that should only be used for Keras support.
+
+It is unsupported and will be removed in favor of Distribution Strategy soon.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+
+
+@contextlib.contextmanager
+def _handle_graph(handle):
+ with handle.graph.as_default():
+ yield
+
+
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while context is not None and not isinstance(
+ context, control_flow_ops.XLAControlFlowContext):
+ context = context.outer_context
+ return context
+
+
+class ReplicatedVariable(object):
+ """A replicated variable for use on TPUs.
+
+ When accessed inside a tpu.replicate() context, this variable acts as if it
+ is a single variable whose handle is a replicated input to the computation.
+
+ Outside a tpu.replicate() context currently this object has pretty murky
+ semantics, especially with respect to things such as
+ * initialization
+ * colocation.
+ """
+
+ def __init__(self, name, variables):
+ self._name = name
+ self._primary_var = variables[0]
+ self._vars = variables
+ self._cached_value = None
+ self._dtype = variables[0].dtype
+
+ @property
+ def handle(self):
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is None:
+ return self._primary_var.handle
+
+ return tpu_context.get_replicated_var_handle(self)
+
+ @contextlib.contextmanager
+ def _assign_dependencies(self):
+ """Makes assignments depend on the cached value, if any.
+
+ This prevents undefined behavior with reads not ordered wrt writes.
+
+ Yields:
+ None.
+ """
+ if self._cached_value is not None:
+ with ops.control_dependencies([self._cached_value]):
+ yield
+ else:
+ yield
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group([v.initializer for v in self._vars])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def op(self):
+ return self.get().op
+
+ @property
+ def is_tensor_like(self):
+ return True
+
+ def _read_variable_op(self):
+ if _enclosing_tpu_context() is None:
+ return self._primary_var.read_value()
+ v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
+ return v
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def is_initialized(self, name=None):
+ return self._vars[0].is_initialized(name=name)
+
+ def __getitem__(self, *args):
+ return self.read_value().__getitem__(*args)
+
+ def assign(self, value, use_locking=None, name=None, read_value=False):
+ """Assign `value` to all replicas.
+
+ Outside of the tpu.rewrite context, assign explicitly to all replicas.
+ Inside of the tpu.rewrite context, assigns to the local replica.
+
+ Arguments:
+ value: Tensor to assign
+ use_locking: ignored
+ name: ignored
+ read_value: return the value from the assignment
+ Returns:
+ Assignment operation, or new value of the variable if `read_value` is True
+ """
+ del use_locking
+ if _enclosing_tpu_context() is None:
+ assign_ops = []
+ with self._assign_dependencies():
+ for var in self._vars:
+ assign_ops.append(var.assign(value, use_locking=None, name=name))
+
+ if read_value:
+ with ops.control_dependencies(assign_ops):
+ return self.read_value()
+ else:
+ return control_flow_ops.group(assign_ops)
+
+ with _handle_graph(self.handle), self._assign_dependencies():
+ value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+ assign_op = gen_resource_variable_ops.assign_variable_op(
+ self.handle, value_tensor, name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_op
+
+ def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_add_op
+
+ def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_sub_op
+
+ def get(self):
+ return self._primary_var
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ return NotImplemented
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+def replicated_fetch_function(var):
+ # pylint: disable=protected-access
+ return ([var._dense_var_to_tensor()], lambda v: v[0])
+ # pylint: enable=protected-access
+
+
+ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
+ops.register_dense_tensor_like_type(ReplicatedVariable)
+session_lib.register_session_run_conversion_functions(
+ ReplicatedVariable, replicated_fetch_function)
+
+
+def replicated_scope(num_replicas):
+ """Variable scope for constructing replicated variables."""
+
+ def _replicated_variable_getter(getter, name, *args, **kwargs):
+ """Getter that constructs replicated variables."""
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ logging.info("Constructing replicated variable %s", name)
+ variables = []
+ index = {}
+ for i in range(num_replicas):
+ replica_name = "{}/{}".format(name, i)
+ with ops.device("device:TPU:{}".format(i)):
+ v = getter(*args, name=replica_name, **kwargs)
+ variables.append(v)
+ index[i] = v
+ result = ReplicatedVariable(name, variables)
+
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ if v in l:
+ l.remove(v)
+ g.add_to_collections(collections, result)
+
+ return result
+
+ return variable_scope.variable_scope(
+ "", custom_getter=_replicated_variable_getter)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 3c735a0b85..1e21cc5252 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -78,10 +78,10 @@ def initialize_system(embedding_config=None, job=None):
embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
- job: The job (the XXX in TensorFlow device specification /job:XXX)
- that contains the TPU devices that will be initialized. If job=None
- it is assumed there is only one job in the TensorFlow flock, and an
- error will be returned if this assumption does not hold.
+ job: The job (the XXX in TensorFlow device specification /job:XXX) that
+ contains the TPU devices that will be initialized. If job=None it is
+ assumed there is only one job in the TensorFlow flock, and an error will
+ be returned if this assumption does not hold.
Returns:
A serialized `TopologyProto` that describes the TPU system. Note:
the topology must be evaluated using `Session.run` before it can be used.
@@ -118,9 +118,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
is a unique name.
- We use a `ControlFlowContext` to perform the annotation since it
- integrates with Tensorflow constructs like ResourceVariables. For example,
- if a `ResourceVariable` is constructed inside a tpu.replicate() block, the
+ We use a `ControlFlowContext` to perform the annotation since it integrates
+ with Tensorflow constructs like ResourceVariables. For example, if a
+ `ResourceVariable` is constructed inside a tpu.replicate() block, the
`ResourceVariable` implementation can use
`with ops.control_dependencies(None)` to build the variable's definition
outside the replicated computation.
@@ -157,8 +157,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def get_replicated_var_handle(self, var):
"""Returns a variable handle for replicated TPU variable 'var'.
- This is an method used by an experimental replicated variable
- implementation and is not intended as a public API.
+ This is a method used by an experimental replicated variable implementation
+ and is not intended as a public API.
Args:
var: The replicated TPU variable.
@@ -211,28 +211,24 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if gradient_uid == "__unsupported__":
raise NotImplementedError(
"No gradient_uid calling gradient within outside_compilation")
- # When we take the gradient of an op X in an
- # outside_compilation cluster C in a forward computation we
- # would like to put the ops corresponding to the gradient of
- # X into a new outside_compilation cluster C'. However, if
- # we take the gradient of X twice, the second one should get
- # yet another new outside_compilation cluster C''.
+ # When we take the gradient of an op X in an outside_compilation
+ # cluster C in a forward computation we would like to put the ops
+ # corresponding to the gradient of X into a new outside_compilation
+ # cluster C'. However, if we take the gradient of X twice, the second
+ # one should get yet another new outside_compilation cluster C''.
#
- # The mechanism we adopt is to use a 'root_cluster' which is
- # the cluster that X was in before we took gradients, and a
- # 'gradient_uid' which is different for every invocation of
- # gradients, and put the gradient of X in cluster
- # 'root_cluster.gradient_uid'.
+ # The mechanism we adopt is to use a 'root_cluster' which is the
+ # cluster that X was in before we took gradients, and a 'gradient_uid'
+ # which is different for every invocation of gradients, and put the
+ # gradient of X in cluster 'root_cluster.gradient_uid'.
#
- # When taking a gradient of a gradient, some ops will be
- # colocated with Op in the forward pass (e.g., cluster
- # root_cluster) and some in the backward pass (e.g., cluster
- # root_cluster.initial_gradient_uid). We need all of the
- # grad-of-grad ops to be in the same cluster to avoid cyclic
- # dependencies between clusters. We adopt a heuristic that
- # puts any op clustered with root_cluster.<xxx> in
- # root_cluster.gradient_uid, even if xxx was
- # initial_gradient_uid.
+ # When taking a gradient of a gradient, some ops will be colocated
+ # with Op in the forward pass (e.g., cluster root_cluster) and some in
+ # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
+ # We need all of the grad-of-grad ops to be in the same cluster to
+ # avoid cyclic dependencies between clusters. We adopt a heuristic
+ # that puts any op clustered with root_cluster.<xxx> in
+ # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
self._in_gradient_colocation = op
parts = outside_attr.split(".")
cluster = parts[0] + "." + gradient_uid
@@ -765,11 +761,10 @@ def shard(computation,
name=None):
"""Shards `computation` for parallel execution.
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list), each of which has a corresponding split axis (from
- `input_shard_axes`). Each input is split into `num_shards` pieces
- along the corresponding axis, and computation is applied to each
- shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list), each
+ of which has a corresponding split axis (from `input_shard_axes`). Each input
+ is split into `num_shards` pieces along the corresponding axis, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -791,10 +786,9 @@ def shard(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). Each input tensor has a corresponding shard axes, given
- by `input_shard_axes`, which must have size divisible by
- `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). Each
+ input tensor has a corresponding shard axes, given by `input_shard_axes`,
+ which must have size divisible by `num_shards`.
num_shards: The number of shards.
input_shard_axes: A list of dimensions along which to shard `inputs`, or
`None`. `None` means "shard all inputs along dimension 0". If not `None`,
@@ -913,9 +907,9 @@ def batch_parallel(computation,
Convenience wrapper around shard().
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list). Each input is split into `num_shards` pieces along the 0-th
- dimension, and computation is applied to each shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list).
+ Each input is split into `num_shards` pieces along the 0-th dimension, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -933,9 +927,8 @@ def batch_parallel(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). The 0-th dimension of each Tensor must have size
- divisible by `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). The
+ 0-th dimension of each Tensor must have size divisible by `num_shards`.
num_shards: The number of shards.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
@@ -968,14 +961,14 @@ def rewrite(computation,
"""Rewrites `computation` for execution on a TPU system.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
- `computation` may return a list of operations and tensors. Tensors must
+ `computation` may return a list of operations and tensors. Tensors must
come before operations in the returned list. The return value of
`rewrite` is a list of tensors corresponding to the tensors from the
- from `computation`.
+ output of `computation`.
All `Operation`s returned from `computation` will be executed when
evaluating any of the returned output tensors.
@@ -1070,12 +1063,12 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
def validate_inference_rewrite_for_variables(graph):
"""Validates whether rewrite_for_inference() 'worked' for variables.
- The rewrite_for_inference() method is supposed to append
- GuaranteeConstOps after ReadVariableOps, but this mechanism works only
- if you are using tf.get_variable() to create and access variables in your
- tpu computation. This validation method can be called immediately after
- calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
- where added to the graph.
+ The rewrite_for_inference() method is supposed to append GuaranteeConstOps
+ after ReadVariableOps, but this mechanism works only if you are using
+ tf.get_variable() to create and access variables in your tpu computation.
+ This validation method can be called immediately after calling
+ tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added
+ to the graph.
Typical usages:
tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())
@@ -1089,10 +1082,9 @@ def validate_inference_rewrite_for_variables(graph):
"""
if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]):
raise RuntimeError(
- "No GuaranteeConst ops found in the graph after "
- "running tpu.rewrite_for_inference(...). Please "
- "check that you are using tf.get_variable() to "
- "create and access variables in your tpu "
+ "No GuaranteeConst ops found in the graph after running "
+ "tpu.rewrite_for_inference(...). Please check that you are using "
+ "tf.get_variable() to create and access variables in your tpu "
"computation.")
@@ -1108,16 +1100,16 @@ def rewrite_for_inference(computation,
in your computation, it moves the ReadVariableOps outside the TPU
computation, and adds GuaranteeConst ops just after the ReadVariableOps.
This mechanism works only if you are using tf.get_variable() to create and
- access variables in your tpu computation. You can validate whether
- this worked, by calling validate_inference_rewrite_for_variables() method
+ access variables in your tpu computation. You can validate whether this
+ worked, by calling validate_inference_rewrite_for_variables() method
immediately after this method to check whether GuaranteeConstOps where
added to the graph.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors. If the function returns m outputs, rewrite
- will return a list of m tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors. If the function returns m outputs, rewrite will return a list of
+ m tensors.
inputs: A list of input tensors or `None` (equivalent to an empty list).
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index 74a675b645..1e11de6421 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
@@ -44,8 +43,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
- group_assignment: Optional list of group ids for applying the optimizer
- to subgroups.
+ group_assignment: Optional 2d int32 lists with shape
+ [num_groups, num_replicas_per_group] which describles how to apply
+ optimizer to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
@@ -74,11 +74,22 @@ class CrossShardOptimizer(optimizer.Optimizer):
"""
if not group_assignment:
return None
- if len(group_assignment) != num_shards:
- raise ValueError("The size of group_assignment does not equal to "
- "num_shard({0}). Got group_assignment={1}".format(
- num_shards, self._group_assignment))
- subgroup_size_list = dict(collections.Counter(group_assignment)).values()
+ if not (isinstance(group_assignment, list) and
+ all(isinstance(i, list) for i in group_assignment)):
+ raise ValueError("group_assignment must be a list of list. Got {}".format(
+ group_assignment))
+
+ replica_ids = set()
+ for g in group_assignment:
+ for i in g:
+ replica_ids.add(i)
+
+ if set(range(num_shards)) != replica_ids:
+ raise ValueError("group_assignment must be a permutation of range({0})."
+ " Got group_assignment={1}".format(
+ num_shards, group_assignment))
+
+ subgroup_size_list = [len(group) for group in group_assignment]
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
return subgroup_size_list[0]
else:
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index 01bac891da..16a647bf66 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -296,6 +296,7 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
def begin(self):
if self._replace_summary_op:
+ # This can still remain None if there are no summaries.
self._summary_op = summary.merge_all()
self._global_step = training_util.get_or_create_global_step()
@@ -304,10 +305,12 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
self._summary_writer = summary.FileWriterCache.get(self._log_dir)
def end(self, session):
- global_step = training_util.global_step(session, self._global_step)
- summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_op is not None:
+ global_step = training_util.global_step(session, self._global_step)
+ summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_writer:
+ self._summary_writer.add_summary(summary_str, global_step)
if self._summary_writer:
- self._summary_writer.add_summary(summary_str, global_step)
self._summary_writer.flush()
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index ec47fe5d97..ddd135f047 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -427,9 +427,11 @@ class EvaluateRepeatedlyTest(test.TestCase):
names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
return names_to_values, names_to_updates
- def _verify_summaries(self, output_dir, names_to_values):
+ def _verify_events(self, output_dir, names_to_values):
"""Verifies that the given `names_to_values` are found in the summaries.
+ Also checks that a GraphDef was written out to the events file.
+
Args:
output_dir: An existing directory where summaries are found.
names_to_values: A dictionary of strings to values.
@@ -440,7 +442,13 @@ class EvaluateRepeatedlyTest(test.TestCase):
self.assertEqual(len(output_filepath), 1)
events = summary_iterator.summary_iterator(output_filepath[0])
- summaries = [e.summary for e in events if e.summary.value]
+ summaries = []
+ graph_def = None
+ for event in events:
+ if event.summary.value:
+ summaries.append(event.summary)
+ elif event.graph_def:
+ graph_def = event.graph_def
values = []
for summary in summaries:
for value in summary.value:
@@ -448,6 +456,7 @@ class EvaluateRepeatedlyTest(test.TestCase):
saved_results = {v.tag: v.simple_value for v in values}
for name in names_to_values:
self.assertAlmostEqual(names_to_values[name], saved_results[name], 5)
+ self.assertIsNotNone(graph_def)
def testSummariesAreFlushedToDisk(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed')
@@ -475,7 +484,23 @@ class EvaluateRepeatedlyTest(test.TestCase):
],
max_number_of_evaluations=1)
- self._verify_summaries(logdir, names_to_values)
+ self._verify_events(logdir, names_to_values)
+
+ def testSummaryAtEndHookWithoutSummaries(self):
+ logdir = os.path.join(self.get_temp_dir(),
+ 'summary_at_end_hook_without_summaires')
+ if gfile.Exists(logdir):
+ gfile.DeleteRecursively(logdir)
+
+ with ops.Graph().as_default():
+ # Purposefully don't add any summaries. The hook will just dump the
+ # GraphDef event.
+ hook = evaluation.SummaryAtEndHook(log_dir=logdir)
+ hook.begin()
+ with self.cached_session() as session:
+ hook.after_create_session(session, None)
+ hook.end(session)
+ self._verify_events(logdir, {})
if __name__ == '__main__':
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 51225f34bc..5c314f359c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -612,6 +612,7 @@ cc_library(
copts = tf_copts(),
deps = tf_lib_proto_parsing_deps() + [
":platform_base",
+ "@com_google_absl//absl/strings",
"@double_conversion//:double-conversion",
],
)
@@ -694,6 +695,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":lib_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -730,10 +733,12 @@ cc_library(
# required to use tf_cc_test, and that rule will change / into _
cc_library(
name = "core_stringpiece",
- srcs = ["lib/core/stringpiece.cc"],
hdrs = ["lib/core/stringpiece.h"],
copts = tf_copts(),
- deps = [":platform_base"],
+ deps = [
+ ":platform_base",
+ "@com_google_absl//absl/strings",
+ ],
)
# Test support library needed for all tests
@@ -868,7 +873,6 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stat_summarizer_options.h",
- "util/status_util.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -935,15 +939,6 @@ cc_library(
)
cc_library(
- name = "status_util",
- hdrs = ["util/status_util.h"],
- deps = [
- ":graph",
- ":lib",
- ],
-)
-
-cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],
hdrs = ["framework/reader_base.h"],
@@ -1589,7 +1584,9 @@ cc_library(
cc_library(
name = "mobile_additional_lib_deps",
- deps = tf_additional_lib_deps(),
+ deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
+ ],
)
# Native library support for iOS applications.
@@ -2067,6 +2064,7 @@ cc_library(
],
}),
deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
"//third_party/eigen3",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
@@ -2084,7 +2082,6 @@ cc_library(
exclude = [
"**/*test*",
"framework/variant.cc",
- "lib/core/stringpiece.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
@@ -2099,7 +2096,6 @@ cc_library(
) + tf_additional_lib_srcs(
exclude = [
"**/*test*",
- "lib/core/stringpiece.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
@@ -2210,6 +2206,7 @@ cc_library(
":lib",
":lib_internal",
"//tensorflow/core/platform/default/build_config:png",
+ "@com_google_absl//absl/strings",
"@zlib_archive//:zlib",
],
)
@@ -2225,7 +2222,7 @@ cc_library(
"platform/macros.h",
"platform/platform.h",
"platform/types.h",
- ],
+ ] + if_windows(["platform/windows/integral_types.h"]),
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
@@ -2260,6 +2257,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
],
)
@@ -2291,6 +2289,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
],
)
@@ -2318,6 +2317,7 @@ cc_library(
linkopts = ["-ldl"],
deps = [
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
"@png_archive//:png",
],
)
@@ -2462,6 +2462,7 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
+ includes = ["../../external/com_google_absl"],
deps = [
":lib",
":lib_internal",
@@ -2546,6 +2547,11 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_headers_lib",
+ extra_deps = [
+ # ABSL headers get dropped, so we add them back here.
+ "@com_google_absl//absl/strings",
+ ],
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2555,6 +2561,7 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -2726,6 +2733,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
+ "common_runtime/lower_while_op.h",
"common_runtime/memory_types.h",
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
@@ -2782,6 +2790,7 @@ tf_cuda_library(
"common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/local_device.cc",
"common_runtime/lower_if_op.cc",
+ "common_runtime/lower_while_op.cc",
"common_runtime/memory_types.cc",
"common_runtime/mkl_cpu_allocator.cc",
"common_runtime/optimization_registry.cc",
@@ -3206,7 +3215,6 @@ tf_cc_tests(
"lib/core/status_test.cc",
"lib/core/stringpiece_test.cc",
"lib/core/threadpool_test.cc",
- "lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/compactptrset_test.cc",
"lib/gtl/edit_distance_test.cc",
@@ -3217,7 +3225,6 @@ tf_cc_tests(
"lib/gtl/iterator_range_test.cc",
"lib/gtl/manual_constructor_test.cc",
"lib/gtl/map_util_test.cc",
- "lib/gtl/optional_test.cc",
"lib/gtl/top_n_test.cc",
"lib/hash/crc32c_test.cc",
"lib/hash/hash_test.cc",
@@ -3543,7 +3550,6 @@ tf_cc_tests(
"util/semver_test.cc",
"util/sparse/sparse_tensor_test.cc",
"util/stat_summarizer_test.cc",
- "util/status_util_test.cc",
"util/tensor_format_test.cc",
"util/tensor_slice_reader_test.cc",
"util/tensor_slice_set_test.cc",
@@ -3568,7 +3574,6 @@ tf_cc_tests(
":ops",
":protos_all_cc",
":protos_test_cc",
- ":status_util",
":test",
":test_main",
":testlib",
@@ -4061,6 +4066,7 @@ tf_cuda_cc_test(
":testlib",
"//third_party/eigen3",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4102,6 +4108,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
# Link with support for TensorFlow Debugger (tfdbg).
"//tensorflow/core/debug",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4586,6 +4593,29 @@ tf_cc_tests(
],
)
+tf_cc_tests(
+ name = "common_runtime_lower_while_op_test",
+ size = "small",
+ srcs = ["common_runtime/lower_while_op_test.cc"],
+ deps = [
+ ":all_kernels",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ ],
+)
+
# Test data
filegroup(
name = "image_testdata",
diff --git a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index ffd01ba5cc..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,3 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..b1cb9a696d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,112 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ in_arg {
+ name: "serialized"
+ description: <<END
+A vector containing binary serialized SequenceExample protos.
+END
+ }
+ in_arg {
+ name: "debug_name"
+ description: <<END
+A vector containing the names of the serialized protos.
+May contain, for example, table key (descriptive) name for the
+corresponding serialized proto. This is purely useful for debugging
+purposes, and the presence of values here has no effect on the output.
+May also be an empty vector if no name is available.
+END
+ }
+ in_arg {
+ name: "context_dense_defaults"
+ description: <<END
+A list of Ncontext_dense Tensors (some may be empty).
+context_dense_defaults[j] provides default values
+when the SequenceExample's context map lacks context_dense_key[j].
+If an empty Tensor is provided for context_dense_defaults[j],
+then the Feature context_dense_keys[j] is required.
+The input type is inferred from context_dense_defaults[j], even when it's
+empty. If context_dense_defaults[j] is not empty, its shape must match
+context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ description: <<END
+A vector listing the
+FeatureList keys which may be missing from the SequenceExamples. If the
+associated FeatureList is missing, it is treated as empty. By default,
+any FeatureList not listed in this vector must exist in the SequenceExamples.
+END
+ }
+ attr {
+ name: "context_sparse_keys"
+ description: <<END
+A list of Ncontext_sparse string Tensors (scalars).
+The keys expected in the Examples' features associated with context_sparse
+values.
+END
+ }
+ attr {
+ name: "context_dense_keys"
+ description: <<END
+A list of Ncontext_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' context features associated with
+dense values.
+END
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ description: <<END
+A list of Nfeature_list_sparse string Tensors
+(scalars). The keys expected in the FeatureLists associated with sparse
+values.
+END
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ description: <<END
+A list of Nfeature_list_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' feature_lists associated
+with lists of dense values.
+END
+ }
+ attr {
+ name: "context_sparse_types"
+ description: <<END
+A list of Ncontext_sparse types; the data types of data in
+each context Feature given in context_sparse_keys.
+Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "context_dense_shapes"
+ description: <<END
+A list of Ncontext_dense shapes; the shapes of data in
+each context Feature given in context_dense_keys.
+The number of elements in the Feature corresponding to context_dense_key[j]
+must always equal context_dense_shapes[j].NumEntries().
+The shape of context_dense_values[j] will match context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ description: <<END
+A list of Nfeature_list_sparse types; the data types
+of data in each FeatureList given in feature_list_sparse_keys.
+Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ description: <<END
+A list of Nfeature_list_dense shapes; the shapes of
+data in each FeatureList given in feature_list_dense_keys.
+The shape of each Feature in the FeatureList corresponding to
+feature_list_dense_key[j] must always equal
+feature_list_dense_shapes[j].NumEntries().
+END
+ }
+ summary: "Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
index 8d6fc04847..9a89a4e8e7 100644
--- a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
@@ -32,7 +32,7 @@ END
description: <<END
a bitmask where a bit i being 1 means to ignore the begin
value and instead use the largest interval possible. At runtime
-begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
`[-1, n-1]` if `stride[i] < 0`
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..3022fccb1e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,12 @@
+op {
+ graph_op_name: "TensorListGather"
+ summary: "Creates a Tensor by indexing into the TensorList."
+ description: <<END
+Each row in the produced Tensor corresponds to the element in the TensorList
+specified by the given index (see `tf.gather`).
+
+input_handle: The input tensor list.
+indices: The indices used to index into the list.
+values: The tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..35194b353e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "TensorListScatter"
+ summary: "Creates a TensorList by indexing into a Tensor."
+ description: <<END
+Each member of the TensorList corresponds to one row of the input tensor,
+specified by the given index (see `tf.gather`).
+
+tensor: The input tensor.
+indices: The indices used to index into the list.
+element_shape: The shape of the elements in the list (can be less specified than
+ the shape of the tensor).
+output_handle: The TensorList.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index 7f721f4fb7..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..4a7e75ba0e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 52eedae9b7..3b2dc6a050 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -192,7 +192,9 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
int next_rank = 0;
while (true) {
selected.insert(next_device);
- DevRec* dr = &(*tdm)[next_device];
+ auto next_dev_it = tdm->find(next_device);
+ CHECK(next_dev_it != tdm->end());
+ DevRec* dr = &next_dev_it->second;
dr->local_rank = next_rank;
++next_rank;
if (selected.size() == tdm->size()) {
@@ -206,9 +208,15 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
parsed_name.id = il.device_id();
string endpoint_device =
DeviceNameUtils::ParsedNameToString(parsed_name);
+ // Skip the device if we've already seen it.
if (selected.find(endpoint_device) != selected.end()) {
continue;
}
+ // Skip the device if it is not participating in this collective
+ // instance.
+ if (tdm->find(endpoint_device) == tdm->end()) {
+ continue;
+ }
if (best_link == nullptr || il.strength() > best_link->strength()) {
best_link = &il;
}
@@ -407,6 +415,10 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
});
}
+// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks
+// to all devices that they are physically connected to and visible to the
+// TensorFlow runtime. This set of devices may be a superset of the devices
+// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const std::vector<DeviceLocality>& localities) {
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index bf1d78ec65..eb388202fa 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -451,8 +451,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
#ifndef __ANDROID__
- // Set up for collectives if the RunOption declares a key.
- if (run_options.experimental().collective_graph_key() > 0) {
+ // Set up for collectives if ExecutorsAndKeys declares a key.
+ if (executors_and_keys->collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ if (run_options.experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ // If a collective_graph_key was specified in run_options, ensure that it
+ // matches what came out of GraphExecutionState::BuildGraph().
+ if (run_options.experimental().collective_graph_key() !=
+ executors_and_keys->collective_graph_key) {
+ return errors::Internal(
+ "collective_graph_key in RunOptions ",
+ run_options.experimental().collective_graph_key(),
+ " should match collective_graph_key from optimized graph ",
+ executors_and_keys->collective_graph_key);
+ }
+ }
if (!collective_executor_mgr_) {
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
@@ -678,10 +692,16 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
+ run_state_args.collective_graph_key =
+ run_options.experimental().collective_graph_key();
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
+ {
+ mutex_lock l(collective_graph_key_lock_);
+ collective_graph_key_ = executors_and_keys->collective_graph_key;
+ }
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
@@ -1116,6 +1136,8 @@ Status DirectSession::CreateExecutors(
BuildGraphOptions options;
options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run;
+ options.collective_graph_key =
+ callable_options.run_options().experimental().collective_graph_key();
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1123,9 +1145,9 @@ Status DirectSession::CreateExecutors(
ek->callable_options = callable_options;
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
- TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
- run_state_args, &ek->input_types,
- &ek->output_types));
+ TF_RETURN_IF_ERROR(CreateGraphs(
+ options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
+ &ek->output_types, &ek->collective_graph_key));
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
@@ -1353,6 +1375,9 @@ Status DirectSession::GetOrCreateExecutors(
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
+ callable_options.mutable_run_options()
+ ->mutable_experimental()
+ ->set_collective_graph_key(run_state_args->collective_graph_key);
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
@@ -1379,7 +1404,7 @@ Status DirectSession::CreateGraphs(
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types) {
+ DataTypeVector* output_types, int64* collective_graph_key) {
mutex_lock l(graph_def_lock_);
std::unique_ptr<ClientGraph> client_graph;
@@ -1403,6 +1428,7 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
+ *collective_graph_key = client_graph->collective_graph_key;
if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 55a6fbce6d..c2cf3c7fd7 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -117,6 +117,9 @@ class DirectSession : public Session {
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private:
+ // For access to collective_graph_key_.
+ friend class DirectSessionCollectiveTest;
+
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
@@ -150,6 +153,8 @@ class DirectSession : public Session {
DataTypeVector output_types;
CallableOptions callable_options;
+
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -203,6 +208,7 @@ class DirectSession : public Session {
string handle;
std::unique_ptr<Graph> graph;
const DebugOptions& debug_options;
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// Initializes the base execution state given the 'graph',
@@ -234,7 +240,7 @@ class DirectSession : public Session {
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types);
+ DataTypeVector* output_types, int64* collective_graph_key);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
@@ -391,6 +397,10 @@ class DirectSession : public Session {
Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
+ // For testing collective graph key generation.
+ mutex collective_graph_key_lock_;
+ int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
+
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdbg) related
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4b51b20bb1..3f2355e530 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -2218,4 +2218,121 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
+
+class DirectSessionCollectiveTest : public ::testing::Test {
+ public:
+ // Creates a graph with CollectiveOps inside functions and runs it. Returns
+ // the generated collective_graph_key.
+ Status RunGraphWithCollectiveFunctions(bool add_unused_function,
+ int64* collective_graph_key) {
+ GraphDef g = CreateGraph(add_unused_function);
+ const Tensor t1 =
+ test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
+ const Tensor t2 =
+ test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
+ auto session = CreateSession();
+ TF_RETURN_IF_ERROR(session->Create(g));
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(
+ session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
+ {"collective_call1:0", "collective_call2:0"}, &outputs));
+ DirectSession* direct_session = static_cast<DirectSession*>(session.get());
+ {
+ mutex_lock l(direct_session->collective_graph_key_lock_);
+ *collective_graph_key = direct_session->collective_graph_key_;
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Creates a function with name `function_name` and a single CollectiveReduce
+ // node with instance key set as `instance_key`.
+ FunctionDef CollectiveFunction(const string& function_name,
+ int instance_key) {
+ return FunctionDefHelper::Define(
+ // Function name
+ function_name,
+ // In def
+ {"arg:float"},
+ // Out def
+ {"reduce:float"},
+ // Attr def
+ {},
+ // Node def
+ {{
+ {"reduce"},
+ "CollectiveReduce",
+ {"arg"},
+ {{"group_size", 2},
+ {"group_key", 1},
+ {"instance_key", instance_key},
+ {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
+ {"merge_op", "Add"},
+ {"final_op", "Div"},
+ {"T", DT_FLOAT}},
+ }});
+ }
+
+ // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
+ // CPU1, with instance_key 1, and appropriate placeholder inputs. If
+ // `add_unused_function` is true, adds another CollectiveFunction with
+ // instance_key 2 that is not invoked in the graph.
+ GraphDef CreateGraph(bool add_unused_function) {
+ GraphDef g;
+ FunctionDef collective_function =
+ CollectiveFunction("CollectiveFunction1", 1);
+ FunctionDefLibrary* lib = g.mutable_library();
+ *lib->add_function() = collective_function;
+ if (add_unused_function) {
+ FunctionDef unused_function =
+ CollectiveFunction("CollectiveFunction2", 2);
+ *lib->add_function() = unused_function;
+ }
+
+ // Inputs.
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input1;
+ input1.set_name("input1");
+ input1.set_op("Placeholder");
+ input1.mutable_attr()->insert({"dtype", dtype_attr});
+ NodeDef input2;
+ input2.set_name("input2");
+ input2.set_op("Placeholder");
+ input2.mutable_attr()->insert({"dtype", dtype_attr});
+
+ // CollectiveReduce on CPU0 with instance_key 1.
+ NodeDef collective_call1;
+ collective_call1.set_name("collective_call1");
+ collective_call1.set_op("CollectiveFunction1");
+ collective_call1.add_input("input1");
+ collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
+ // CollectiveReduce on CPU1 with instance_key 1.
+ NodeDef collective_call2;
+ collective_call2.set_name("collective_call2");
+ collective_call2.set_op("CollectiveFunction1");
+ collective_call2.add_input("input2");
+ collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
+
+ *g.add_node() = input1;
+ *g.add_node() = input2;
+ *g.add_node() = collective_call1;
+ *g.add_node() = collective_call2;
+
+ return g;
+ }
+};
+
+#ifndef GOOGLE_CUDA
+// TODO(ayushd): enable this test for GPU builds.
+TEST_F(DirectSessionCollectiveTest,
+ TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
+ int64 key1;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
+ int64 key2;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
+ ASSERT_EQ(key1, key2);
+}
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 7f28f3b793..be5f3bae3a 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -219,7 +219,9 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":kernel_and_device",
- "//tensorflow/c:c_api",
+ # Only the TF_AttrType enum is required, so pull in just the C headers.
+ # TODO(b/113535673): Break this dependency and avoid the C header completely.
+ "//tensorflow/c:c_api_headers",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
@@ -249,6 +251,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index b859b06fa0..39a3b49cd1 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -26,7 +26,7 @@ namespace {
bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
bool val;
- if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
+ if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
return val;
}
return default_val;
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 02193dae5a..84865397bc 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1482,6 +1482,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
+ delete this;
done(fill_status);
return;
}
@@ -1492,6 +1493,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
ready.push_back(TaggedNode{n, root_frame_, 0, false});
}
if (ready.empty()) {
+ delete this;
done(Status::OK());
} else {
num_outstanding_ops_ = ready.size();
@@ -2419,8 +2421,7 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
}
if (dst_ready) {
if (IsControlTrigger(dst_node)) dst_dead = false;
- ready->push_back(
- TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
parent_iter_state->outstanding_ops++;
}
}
@@ -2544,7 +2545,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
// Add dst to the ready queue if it's ready
if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false;
- ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
+ ready->emplace_back(dst_item->node, this, iter, dst_dead);
iter_state->outstanding_ops++;
}
}
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index fb89bcc0df..46bb8d92f8 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -310,6 +310,7 @@ class CallOp : public AsyncOpKernel {
opts.step_container = ctx->step_container();
opts.stats_collector = ctx->stats_collector();
opts.runner = ctx->runner();
+ opts.collective_executor = ctx->collective_executor();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -346,9 +347,10 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
return nullptr;
}
- mutex_lock l(mu_);
- CHECK_EQ(1, items_.count(local_handle));
- return items_[local_handle]->func_graph;
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ CHECK(iter != items_.end());
+ return iter->second->func_graph;
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
@@ -633,7 +635,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionLibraryDefinition* lib_def;
string executor_type;
{
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
executor_type = (*item)->executor_type;
@@ -682,12 +684,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
{
- mutex_lock l(mu_);
- if (items_.count(local_handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ if (iter == items_.end()) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle].get();
+ *item = iter->second.get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3292ef2f62..2763ac0d4a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -917,16 +917,21 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
const auto& gpu_options = options.config.gpu_options();
std::vector<CudaGpuId> visible_gpu_order;
- TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
- &visible_gpu_order));
-
std::vector<CudaGpuId> valid_cuda_gpu_ids;
- TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ // If we aren't going to use any GPUs, don't initialize them.
+ // We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
+ // because it treats an empty gpu_options.visible_device_list as 'all GPUs are
+ // visible'.
+ if (num_gpus_to_use > 0) {
+ TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
+ &visible_gpu_order));
+ TF_RETURN_IF_ERROR(
+ GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ }
if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
num_gpus_to_use = valid_cuda_gpu_ids.size();
}
- // If we aren't going to use any GPUs, don't initialize them.
- if (num_gpus_to_use > 0 && !valid_cuda_gpu_ids.empty()) {
+ if (!valid_cuda_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 346befc255..7f260b3139 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory>
+#include <set>
#include <string>
#include <unordered_set>
#include <utility>
@@ -727,12 +728,50 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
+ int64 collective_graph_key = options.collective_graph_key;
+ if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // BuildGraphOptions does not specify a collective_graph_key. Check all
+ // nodes in the Graph and FunctionLibraryDefinition for collective ops and
+ // if found, initialize a collective_graph_key as a hash of the ordered set
+ // of instance keys.
+ std::set<int32> instance_key_set;
+ for (Node* node : optimized_graph->nodes()) {
+ if (node->IsCollective()) {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ } else {
+ const FunctionDef* fdef = optimized_flib->Find(node->def().op());
+ if (fdef != nullptr) {
+ for (const NodeDef& ndef : fdef->node_def()) {
+ if (ndef.op() == "CollectiveReduce" ||
+ ndef.op() == "CollectiveBcastSend" ||
+ ndef.op() == "CollectiveBcastRecv") {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ndef, "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ }
+ }
+ }
+ }
+ }
+ if (!instance_key_set.empty()) {
+ uint64 hash = 0x8774aa605c729c72ULL;
+ for (int32 instance_key : instance_key_set) {
+ hash = Hash64Combine(instance_key, hash);
+ }
+ collective_graph_key = hash;
+ }
+ }
+
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
- rewrite_metadata.fetch_types));
+ rewrite_metadata.fetch_types, collective_graph_key));
CopyGraph(*optimized_graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h
index d44a24c87b..9cabe478a6 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.h
+++ b/tensorflow/core/common_runtime/graph_execution_state.h
@@ -50,17 +50,20 @@ struct GraphExecutionStateOptions {
// BuildGraphOptions.
struct ClientGraph {
explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
- DataTypeVector feed_types, DataTypeVector fetch_types)
+ DataTypeVector feed_types, DataTypeVector fetch_types,
+ int64 collective_graph_key)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
- fetch_types(std::move(fetch_types)) {}
+ fetch_types(std::move(fetch_types)),
+ collective_graph_key(collective_graph_key) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
DataTypeVector feed_types;
DataTypeVector fetch_types;
+ int64 collective_graph_key;
};
// GraphExecutionState is responsible for generating an
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
new file mode 100644
index 0000000000..1f5da133e9
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -0,0 +1,427 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+
+namespace {
+
+using NodeOut = NodeBuilder::NodeOut;
+
+// Helper to convert a functional While op to its lowered form.
+//
+// Example:
+//
+// Input graph:
+//
+// loop_var -> WhileOp<cond_func, body_func> -> consumer
+//
+// Output graph(top to down flow):
+//
+// loop_var
+// |
+// Enter
+// |
+// inlined_cond_func ---<--- Merge -----<----- NextIteration
+// | | |
+// V V ^
+// | | |
+// LoopCond ------>-------- Switch ---->---- inlined_body_func
+// |
+// Exit
+// |
+// consumer
+class LowerWhileHelper {
+ public:
+ static Status Run(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph) {
+ LowerWhileHelper helper(while_op, cond_fn_name, body_fn_name, graph);
+ return helper.RunInternal();
+ }
+
+ private:
+ // Create a LowerWhileHelper to create the lowering of While op that has cond
+ // and body functions named `cond_fn_name` and `body_fn_name` respectively in
+ // the given graph.
+ LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph);
+
+ Status RunInternal();
+
+ // Creates an Enter node for each `while_op_` input and adds them to
+ // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
+ // `src` node we add a control edge from `src` to each Enter node.
+ Status CreateEnterNodes();
+
+ // Creates a Merge node for each Enter node and adds to `merge_nodes_`.
+ // Initially now both inputs of a Merge node are the Enter node. Input at
+ // index 1 is later updated to the output of NextIteration node in
+ // `UpdateMergeNodes`.
+ Status CreateMergeNodes();
+
+ // Creates the call node for cond func and stores in `cond_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateCondFuncCallNode();
+
+ // Creates a Switch node for each loop var and adds to `switch_nodes_`.
+ // Output at index 1(true) of a Switch node is fed into the loop body.
+ // Output at index 0(false) of a Switch node is fed into the Exit nodes.
+ Status CreateSwitchNodes();
+
+ // Creates the call node for body func and stores in `body_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateBodyFuncCallNode();
+
+ // Creates an Exit node for each loop var and adds to `exit_nodes_`. These
+ // are fed into the consumers of the `while_op_`.
+ Status CreateExitNodes();
+
+ // Creates an NextIteration node for each loop var and adds to
+ // `next_iteration_nodes_`.
+ Status CreateNextIterationNodes();
+
+ // Updates input at index 1 of each merge node created in `CreateMergeNodes`
+ // to use the output of NextIteration node created in
+ // `CreateNextIterationNodes` instead.
+ Status UpdateMergeNodes();
+
+ // Updates consumers of the original `while_op_` to instead use the outputs
+ // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
+ // edges to depend on `lowered_while_output_` instead.
+ Status UpdateConsumers();
+
+ // Inlines the cond and body functions.
+ Status InlineCallNodes();
+
+ // Returns unique name containing the name of the While op being rewritten
+ // (name_), infix and a suffix to ensure it is unique within the graph.
+ string NewName(const string& infix);
+
+ // The original While op.
+ Node* while_op_;
+ // The call node for the cond branch. This gets inlined.
+ Node* cond_call_node_;
+ // The LoopCond node specifying the loop termination condition.
+ Node* loop_cond_node_;
+ // The call node for the body branch. This gets inlined.
+ Node* body_call_node_;
+ // The IdentityN node with the same outputs as the original While op.
+ Node* lowered_while_output_;
+ Graph* graph_;
+ // Name of the `while_op_`.
+ string name_;
+
+ NodeBuilder cond_call_builder_;
+ NodeBuilder body_call_builder_;
+
+ std::vector<Node*> enter_nodes_;
+ std::vector<Node*> merge_nodes_;
+ std::vector<Node*> switch_nodes_;
+ std::vector<Node*> exit_nodes_;
+ std::vector<Node*> next_iterations_nodes_;
+
+ size_t num_loop_inputs_;
+};
+
+LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph)
+ : while_op_(while_op),
+ graph_(graph),
+ name_(while_op->name()),
+ cond_call_builder_(NewName("cond"), cond_fn_name, graph->op_registry()),
+ body_call_builder_(NewName("body"), body_fn_name, graph->op_registry()),
+ num_loop_inputs_(while_op_->num_inputs()) {
+ // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
+ // because we need to set it's elements out of order in `CreateEnterNodes`.
+ enter_nodes_.resize(num_loop_inputs_);
+ merge_nodes_.reserve(num_loop_inputs_);
+ switch_nodes_.reserve(num_loop_inputs_);
+ exit_nodes_.reserve(num_loop_inputs_);
+ next_iterations_nodes_.reserve(num_loop_inputs_);
+}
+
+Status LowerWhileHelper::RunInternal() {
+ TF_RETURN_IF_ERROR(CreateEnterNodes());
+ TF_RETURN_IF_ERROR(CreateMergeNodes());
+ TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateSwitchNodes());
+ TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateExitNodes());
+ TF_RETURN_IF_ERROR(CreateNextIterationNodes());
+ TF_RETURN_IF_ERROR(UpdateMergeNodes());
+ TF_RETURN_IF_ERROR(UpdateConsumers());
+ TF_RETURN_IF_ERROR(InlineCallNodes());
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateEnterNodes() {
+ // Note: `Node::input_edge` runs in O(num_inputs) so we use
+ // `Node::input_edges` instead so that below loop runs in O(num_inputs) time
+ // and not O(num_inputs^2).
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
+ for (const Edge* edge : edges) {
+ Node* enter_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("enter"), "Enter", graph_->op_registry())
+ .Input(NodeOut(edge->src(), edge->src_output()))
+ .Attr("frame_name", name_)
+ .Finalize(graph_, &enter_node));
+ enter_nodes_[edge->dst_input()] = enter_node;
+ }
+ // Create a NoOp node that takes incoming control inputs of the original While
+ // op as control inputs and use it as a control input for all Enter nodes.
+ std::vector<Node*> control_inputs;
+ for (const Edge* e : while_op_->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_inputs.push_back(e->src());
+ }
+ }
+ if (!control_inputs.empty()) {
+ Node* incoming_control_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopControlInputs"), "NoOp", graph_->op_registry())
+ .ControlInputs(control_inputs)
+ .Finalize(graph_, &incoming_control_node));
+ for (Node* n : enter_nodes_) {
+ graph_->AddControlEdge(incoming_control_node, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateMergeNodes() {
+ for (Node* enter_node : enter_nodes_) {
+ Node* merge_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("merge"), "Merge", graph_->op_registry())
+ .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
+ .Finalize(graph_, &merge_node));
+ merge_nodes_.emplace_back(merge_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateCondFuncCallNode() {
+ for (Node* merge_node : merge_nodes_) {
+ cond_call_builder_.Input(NodeOut(merge_node, 0));
+ }
+ TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
+ // Add a control edge to make sure the Const nodes in the cond function
+ // are in the same frame as the rest of the function, otherwise
+ // `BuildControlFlowInfo` throws an error.
+ graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopCond"), "LoopCond", graph_->op_registry())
+ .Input(NodeOut(cond_call_node_, 0))
+ .Finalize(graph_, &loop_cond_node_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateSwitchNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ string op_name;
+ {
+ const Node* input_node;
+ TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
+ op_name = strings::StrCat(input_node->name(), "_switch");
+ }
+ Node* switch_node;
+ string op_type = "Switch";
+ if (IsRefType(merge_nodes_[i]->output_type(0))) {
+ op_type = "RefSwitch";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName(op_name), op_type, graph_->op_registry())
+ .Input(NodeOut(merge_nodes_[i], 0))
+ .Input(NodeOut(loop_cond_node_, 0))
+ .Finalize(graph_, &switch_node));
+ switch_nodes_.emplace_back(switch_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateBodyFuncCallNode() {
+ for (Node* switch_node : switch_nodes_) {
+ body_call_builder_.Input(NodeOut(switch_node, 1));
+ }
+ TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
+ // Add a control edge to make sure the Const nodes in the body function
+ // are in the same frame as the rest of the function, otherwise
+ // `BuildControlFlowInfo` throws an error.
+ // TODO(srbs): The choice of input at index 0 seems arbitrary(is it?) however
+ // this is how tf.while_loop does it. Can this affect performance if the 0th
+ // node is not the first one to be ready? Can we speed that case up using some
+ // sort of multi-input Merge?
+ Node* body_control_node_;
+ string op_type = "Identity";
+ if (IsRefType(switch_nodes_[0]->output_type(1))) {
+ op_type = "RefIdentity";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("loop_body_control"), op_type, graph_->op_registry())
+ .Input(NodeOut(switch_nodes_[0], 1))
+ .Finalize(graph_, &body_control_node_));
+ graph_->AddControlEdge(body_control_node_, body_call_node_);
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateExitNodes() {
+ std::vector<NodeOut> outputs;
+ outputs.reserve(num_loop_inputs_);
+ for (Node* switch_node : switch_nodes_) {
+ Node* exit_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("exit"), "Exit", graph_->op_registry())
+ .Input(NodeOut(switch_node, 0))
+ .Finalize(graph_, &exit_node));
+ exit_nodes_.emplace_back(exit_node);
+ outputs.emplace_back(NodeOut(exit_node, 0));
+ }
+
+ // Add an IdentityN node that has the same outputs and same name as the
+ // original functional While op. This is used for
+ // 1. Rewiring the control edges with the original while op as src.
+ // 2. Fetching the output of the While node by name in calls to sess.run.
+ NodeBuilder ib(name_, "IdentityN");
+ ib.Input(outputs);
+ TF_RETURN_IF_ERROR(ib.Finalize(graph_, &lowered_while_output_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateNextIterationNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ Node* next_iteration;
+ TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
+ graph_->op_registry())
+ .Input(NodeOut(body_call_node_, i))
+ .Finalize(graph_, &next_iteration));
+ next_iterations_nodes_.emplace_back(next_iteration);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateMergeNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ TF_RETURN_IF_ERROR(
+ graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateConsumers() {
+ for (const Edge* e : while_op_->out_edges()) {
+ if (e->IsControlEdge()) {
+ graph_->AddControlEdge(lowered_while_output_, e->dst());
+ } else {
+ // Feed the outputs directly from the exit nodes so that downstream ops
+ // can start before all the outputs have been computed.
+ graph_->AddEdge(exit_nodes_[e->src_output()], 0, e->dst(),
+ e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+string LowerWhileHelper::NewName(const string& infix) {
+ return graph_->NewName(strings::StrCat(name_, "/", infix));
+}
+
+Status InlineCallInGraph(Node* n, Graph* g) {
+ const auto& lib = g->flib_def();
+ const FunctionDef* fdef = lib.Find(n->type_string());
+ CHECK(fdef != nullptr);
+ FunctionBody* fbody;
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
+ [&lib](const string& op, const OpDef** sig) {
+ return lib.LookUpOpDef(op, sig);
+ },
+ &fbody));
+ // TODO(jpienaar): Improve this interface to make the need to delete it
+ // explicit.
+ InlineFunctionBody(g->flib_def(), g, n, fbody, false);
+ delete fbody;
+ return Status::OK();
+}
+
+Status LowerWhileHelper::InlineCallNodes() {
+ TF_RETURN_IF_ERROR(InlineCallInGraph(cond_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(body_call_node_, graph_));
+ return Status::OK();
+}
+
+} // namespace
+
+Status LowerWhileOpPass::Run(const GraphOptimizationPassOptions& options) {
+ if (options.partition_graphs != nullptr) {
+ return errors::Internal(
+ "Lowering While op should happen before partitioning.");
+ }
+ if (options.graph == nullptr) {
+ return Status::OK();
+ }
+
+ Graph* g = options.graph->get();
+ if (g == nullptr) {
+ return errors::Internal(
+ "Lowering While op requires a graph to be available.");
+ }
+
+ // Match all the nodes that need to be rewritten.
+ gtl::InlinedVector<Node*, 2> matches;
+ for (Node* n : g->op_nodes()) {
+ if (n->type_string() == "While") {
+ // Only rewrite if the While op is marked as needing to be lowered.
+ bool match;
+ Status s = GetNodeAttr(n->attrs(),
+ LowerIfOpPass::kLowerUsingSwitchMergeAttr, &match);
+ if (s.ok() && match) matches.push_back(n);
+ }
+ }
+ for (Node* n : matches) {
+ TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileOpPass::RewriteNode(Node* n, Graph* g) {
+ const AttrValue* cond_attr = n->attrs().Find("cond");
+ if (cond_attr == nullptr) {
+ return errors::InvalidArgument("While cond function missing");
+ }
+ const AttrValue* body_attr = n->attrs().Find("body");
+ if (body_attr == nullptr) {
+ return errors::InvalidArgument("While body function missing");
+ }
+
+ TF_RETURN_IF_ERROR(LowerWhileHelper::Run(n, cond_attr->func().name(),
+ body_attr->func().name(), g));
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
+ LowerWhileOpPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/common_runtime/lower_while_op.h
index 1f06004db2..eadafbeb91 100644
--- a/tensorflow/core/util/status_util_test.cc
+++ b/tensorflow/core/common_runtime/lower_while_op.h
@@ -13,24 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/status_util.h"
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-namespace {
-
-TEST(TestStatusUtil, ErrorFormatTagForNode) {
- Graph graph(OpRegistry::Global());
- Node* node;
- TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node));
- EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^");
- EXPECT_EQ(error_format_tag(*node, "${file}:${line}"),
- "^^node:Foo:${file}:${line}^^");
-}
-
-} // namespace
+
+// Rewrite While ops to use lower level control flow primitives instead.
+class LowerWhileOpPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ private:
+ // Rewrite the given While node `n` in graph `g` to use the lower level
+ // primitives Enter, Exit, Switch, Merge and NextIteration.
+ Status RewriteNode(Node* n, Graph* g);
+};
+
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
new file mode 100644
index 0000000000..27cbada004
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status Rewrite(std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = &flib_def;
+ LowerWhileOpPass pass;
+ return pass.Run(opt_options);
+}
+
+TEST(LowerWhileOpTest, Simple) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *f_lib_proto.add_function() = test::function::XTimesTwo();
+ *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("LessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XTimesTwo");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ int node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ }
+ ASSERT_EQ(node_called_while_count, 1);
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // One node per loop input.
+ ASSERT_EQ(enter_count, 1);
+ ASSERT_EQ(exit_count, 1);
+ ASSERT_EQ(switch_count, 1);
+ ASSERT_EQ(merge_count, 1);
+ ASSERT_EQ(next_iteration_count, 1);
+ ASSERT_EQ(node_called_while_count, 1);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
+ }
+}
+
+TEST(LowerWhileOpTest, MultipleInputs) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *(f_lib_proto.add_function()) = test::function::XPlusOneXTimesY();
+ *(f_lib_proto.add_function()) = test::function::XYXLessThanOrEqualToN(4);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs(
+ {NodeBuilder::NodeOut(a.node()), NodeBuilder::NodeOut(b.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("XYXLessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XPlusOneXTimesY");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32, DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ }
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // Two nodes per loop input.
+ ASSERT_EQ(enter_count, 2);
+ ASSERT_EQ(exit_count, 2);
+ ASSERT_EQ(switch_count, 2);
+ ASSERT_EQ(merge_count, 2);
+ ASSERT_EQ(next_iteration_count, 2);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ feeds.emplace(Output(b.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 24);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ feeds.emplace(Output(b.node()), Input::Initializer(5));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 60);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 99bd43e090..6b76e7e0e7 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -148,12 +148,14 @@ class MklCPUAllocator : public VisitableAllocator {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
static inline void* ReallocHook(void* ptr, size_t size) {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
/// Do we allow growth in BFC Allocator
@@ -166,6 +168,9 @@ class MklCPUAllocator : public VisitableAllocator {
static constexpr const size_t kAlignment = 64;
VisitableAllocator* allocator_; // owned by this class
+
+ // Prevent copying and assignment
+ TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index d581f45a90..7f3c25d81d 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/status_util.h"
namespace tensorflow {
@@ -934,14 +933,13 @@ bool Placer::ClientHandlesErrorFormatting() const {
// Returns the node name in single quotes. If the client handles formatted
// errors, appends a formatting tag which the client will reformat into, for
// example, " (defined at filename:123)".
+// TODO(shikharagarwal): Remove this function once
+// client_handles_error_formatting flag is removed.
string Placer::RichNodeName(const Node* node) const {
- string quoted_name = strings::StrCat("'", node->name(), "'");
if (ClientHandlesErrorFormatting()) {
- string file_and_line = error_format_tag(*node, "${defined_at}");
- return strings::StrCat(quoted_name, file_and_line);
- } else {
- return quoted_name;
+ return errors::FormatNodeNameForError(node->name());
}
+ return strings::StrCat("'", node->name(), "'");
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 87f2f2ceb9..83d27e2730 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -1159,9 +1159,8 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
- EXPECT_TRUE(str_util::StrContains(s.error_message(),
- "Cannot assign a device for operation 'in'"
- "^^node:in:${defined_at}^^"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Cannot assign a device for operation {{node in}}"));
}
// Test that the "Cannot assign a device" error message does not contain a
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 6dac4c3acf..c43a9d7dc2 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -113,7 +113,7 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
- std::vector<Tensor>* received_tensors, const StatusCallback& done) {
+ std::vector<Tensor>* received_tensors, StatusCallback done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
@@ -121,9 +121,8 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
- RecvOutputsFromRendezvousAsync(
- rendezvous, device_context, alloc_attrs, keys, received_tensors,
- [done](const Status& status) { done(status); });
+ RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
+ received_tensors, std::move(done));
}
Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
@@ -192,7 +191,7 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}
@@ -204,11 +203,12 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
return kInvalidLocalHandle;
}
- FunctionData* function_data = function_data_[handle].get();
+ FunctionData* function_data = iter->second.get();
if (function_data->target_device() != device_name) {
return kInvalidLocalHandle;
}
@@ -217,9 +217,10 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- CHECK_EQ(1, function_data_.count(handle));
- FunctionData* function_data = function_data_[handle].get();
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ CHECK(iter != function_data_.end());
+ FunctionData* function_data = iter->second.get();
return function_data->target_device();
}
@@ -302,13 +303,15 @@ void ProcessFunctionLibraryRuntime::Run(
string target_device;
FunctionLibraryRuntime::LocalHandle local_handle;
{
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
done(errors::NotFound("Handle: ", handle, " not found."));
return;
}
- target_device = function_data_[handle]->target_device();
- local_handle = function_data_[handle]->local_handle();
+ FunctionData* function_data = iter->second.get();
+ target_device = function_data->target_device();
+ local_handle = function_data->local_handle();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
@@ -339,26 +342,29 @@ void ProcessFunctionLibraryRuntime::Run(
opts.rets_alloc_attrs;
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
- [source_device, target_device, target_incarnation, rendezvous,
- device_context, rets_alloc_attrs, remote_rets, rets,
- done](const Status& status) {
- if (!status.ok()) {
- delete remote_rets;
- done(status);
- return;
- }
- int64 num_returns = remote_rets->size();
- delete remote_rets;
- // Now receive the return values from the target.
- ReceiveTensorsAsync(target_device, source_device, "ret_",
- target_incarnation, num_returns,
- device_context, rets_alloc_attrs, rendezvous,
- rets, done);
- });
+ std::bind(
+ [source_device, target_device, target_incarnation, rendezvous,
+ device_context, rets_alloc_attrs, remote_rets,
+ rets](const Status& status,
+ FunctionLibraryRuntime::DoneCallback& done) {
+ if (!status.ok()) {
+ delete remote_rets;
+ done(status);
+ return;
+ }
+ int64 num_returns = remote_rets->size();
+ delete remote_rets;
+ // Now receive the return values from the target.
+ ReceiveTensorsAsync(target_device, source_device, "ret_",
+ target_incarnation, num_returns,
+ device_context, rets_alloc_attrs,
+ rendezvous, rets, std::move(done));
+ },
+ std::placeholders::_1, std::move(done)));
return;
}
if (parent_ != nullptr) {
- parent_->Run(opts, local_handle, args, rets, done);
+ parent_->Run(opts, local_handle, args, rets, std::move(done));
return;
}
done(errors::Internal("Could not find device"));
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 69381dd34d..53815715d8 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -59,8 +60,6 @@ class ProcessFunctionLibraryRuntime {
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous);
- typedef std::function<void(const Status&)> StatusCallback;
-
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. `device_context` should be for the device receiving
@@ -73,7 +72,7 @@ class ProcessFunctionLibraryRuntime {
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 92dc03812e..1e3fed0d6f 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
namespace tensorflow {
Status SendTensorsToRendezvous(
@@ -54,7 +56,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done) {
+ StatusCallback done) {
if (keys.empty()) {
done(Status::OK());
return;
@@ -85,13 +87,7 @@ void RecvOutputsFromRendezvousAsync(
alloc_attr);
}
- typedef struct {
- mutex mu;
- int64 done_counter;
- Status shared_status = Status::OK();
- } CallState;
- CallState* call_state = new CallState;
- call_state->done_counter = keys.size();
+ auto status_cb = new ReffedStatusCallback(std::move(done));
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
@@ -99,13 +95,13 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous::Args rendez_args;
rendez_args.device_context = device_context;
rendez_args.alloc_attrs = std::get<3>(p);
-
+ status_cb->Ref();
rendezvous->RecvAsync(
parsed, rendez_args,
- [val, done, key, call_state](const Status& s,
- const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args,
- const Tensor& v, const bool is_dead) {
+ [val, key, status_cb](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
@@ -114,20 +110,11 @@ void RecvOutputsFromRendezvousAsync(
" was not valid.");
}
}
- call_state->mu.lock();
- call_state->shared_status.Update(status);
- call_state->done_counter--;
- // If we are the last async call to return, call the done callback.
- if (call_state->done_counter == 0) {
- const Status& final_status = call_state->shared_status;
- call_state->mu.unlock();
- done(final_status);
- delete call_state;
- return;
- }
- call_state->mu.unlock();
+ status_cb->UpdateStatus(status);
+ status_cb->Unref();
});
}
+ status_cb->Unref();
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
index aad910f6d8..deb9a7c822 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.h
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -42,7 +43,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index bb8eeb141a..a81f8650bf 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -497,13 +498,6 @@ bool RingReducer::RunAsyncParts() {
rfv_.clear();
rfv_.resize(group_size_ * num_subdivs_);
PCQueue ready_queue;
- int field_done_count = 0;
- int send_pending_count = 0;
- int recv_pending_count = 0;
- std::atomic<bool> aborted(false);
- field_done_count = 0;
- send_pending_count = 0;
- recv_pending_count = 0;
for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
@@ -511,6 +505,30 @@ bool RingReducer::RunAsyncParts() {
ready_queue.Enqueue(&rfv_[rf_index]);
}
}
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ col_ctx_->device->tensorflow_gpu_device_info();
+ if (gpu_info) {
+ // Wait for all currently queued events on the CPU compute stream to
+ // complete before proceeding. The previous InitRingField calls allocated
+ // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
+ // write) unless we do.
+ Notification note;
+ Status s = gpu_info->default_context->ThenExecute(
+ col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
+ if (s.ok()) {
+ note.WaitForNotification();
+ } else {
+ mutex_lock l(status_mu_);
+ status_ =
+ errors::Internal("Failed to dispatch ThenExecute in RingReducer");
+ return false;
+ }
+ }
+
+ int field_done_count = 0;
+ int send_pending_count = 0;
+ int recv_pending_count = 0;
+ std::atomic<bool> aborted(false);
// Loop until all RingFields have advanced to completion.
while (field_done_count < rfv_.size()) {
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 5e079dbce6..28df85399e 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -260,13 +260,17 @@ class RingReducerTest : public ::testing::Test {
}
}
- void Reduce() {
+ void Reduce(int fail_after) {
std::atomic<int> done(0);
for (auto di : instances_) {
SchedClosure([di, &done] {
di->DoReduce();
++done;
});
+ if (fail_after > 0) {
+ // Stagger the op execution starts.
+ Env::Default()->SleepForMicroseconds(100);
+ }
}
while (done < static_cast<int>(instances_.size())) {
if (stop_) break;
@@ -296,7 +300,7 @@ class RingReducerTest : public ::testing::Test {
}
});
}
- Reduce();
+ Reduce(fail_after);
if (fail_after > 0) {
// Confirm that every device terminated with the expected error status.
for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
@@ -640,6 +644,7 @@ DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)
// Failure tests
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 1)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
#endif
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 949d034dea..38863db1cc 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <stddef.h>
#include <string.h>
#include <cmath>
+#include <cstdlib>
+#include <cstring>
#include <limits>
#include <utility>
#include <vector>
@@ -418,6 +420,19 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
+ const int64 tensorBytes =
+ tensor.IsInitialized() ? tensor.TotalBytes() : 0;
+ if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
+ return errors::ResourceExhausted(
+ "TensorFlow Debugger has exhausted file-system byte-size "
+ "allowance (",
+ DebugFileIO::globalDiskBytesLimit, "), therefore it cannot ",
+ "dump an additional ", tensorBytes, " byte(s) of tensor data ",
+ "for the debug tensor ", debug_node_key.node_name, ":",
+ debug_node_key.output_slot, ". You may use the environment ",
+ "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
+ }
+
Status s = DebugFileIO::DumpTensorToDir(
debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
if (!s.ok()) {
@@ -670,6 +685,42 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
+// Default total disk usage limit: 100 GBytes
+const uint64 DebugFileIO::defaultGlobalDiskBytesLimit = 107374182400L;
+uint64 DebugFileIO::globalDiskBytesLimit = 0;
+uint64 DebugFileIO::diskBytesUsed = 0;
+
+mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
+
+bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+ if (globalDiskBytesLimit == 0) {
+ const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
+ if (env_tfdbg_disk_bytes_limit == nullptr ||
+ strlen(env_tfdbg_disk_bytes_limit) == 0) {
+ globalDiskBytesLimit = defaultGlobalDiskBytesLimit;
+ } else {
+ strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
+ &globalDiskBytesLimit);
+ }
+ }
+
+ if (bytes == 0) {
+ return true;
+ }
+ mutex_lock l(bytes_mu);
+ if (diskBytesUsed + bytes < globalDiskBytesLimit) {
+ diskBytesUsed += bytes;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void DebugFileIO::resetDiskByteUsage() {
+ mutex_lock l(bytes_mu);
+ diskBytesUsed = 0;
+}
+
#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: server_stream_addr_(server_stream_addr),
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index cedb7386b7..5390ce408a 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -193,6 +193,26 @@ class DebugFileIO {
const string& dir_name,
const string& file_name);
+ // Request additional bytes to be dumped to the file system.
+ //
+ // Does not actually dump the bytes, but instead just performs the
+ // bookkeeping necessary to prevent the total dumped amount of data from
+ // exceeding the limit (default 100 GBytes or set customly through the
+ // environment variable TFDBG_DISK_BYTES_LIMIT).
+ //
+ // Args:
+ // bytes: Number of bytes to request.
+ //
+ // Returns:
+ // Whether the request is approved given the total dumping
+ // limit.
+ static bool requestDiskByteUsage(uint64 bytes);
+
+ // Reset the disk byte usage to zero.
+ static void resetDiskByteUsage();
+
+ static uint64 globalDiskBytesLimit;
+
private:
// Encapsulates the Tensor in an Event protobuf and write it to file.
static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
@@ -204,6 +224,15 @@ class DebugFileIO {
// TODO(cais): Replace with shared implementation once http://b/30497715 is
// fixed.
static Status RecursiveCreateDir(Env* env, const string& dir);
+
+ // Tracks how much disk has been used so far.
+ static uint64 diskBytesUsed;
+ // Mutex for thread-safe access to diskBytesUsed.
+ static mutex bytes_mu;
+ // Default limit for the disk space.
+ static const uint64 defaultGlobalDiskBytesLimit;
+
+ friend class DiskUsageLimitTest;
};
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index 0807a85b8b..82e0ae5edb 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstdlib>
#include <unordered_set>
#include "tensorflow/core/debug/debug_io_utils.h"
@@ -454,5 +455,50 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
}
}
+class DiskUsageLimitTest : public ::testing::Test {
+ public:
+ void Initialize() {
+ setenv("TFDBG_DISK_BYTES_LIMIT", "", 1);
+ DebugFileIO::resetDiskByteUsage();
+ DebugFileIO::globalDiskBytesLimit = 0;
+ }
+};
+
+TEST_F(DiskUsageLimitTest, RequestWithZeroByteIsOkay) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(0L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterOneCall) {
+ Initialize();
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(100L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterTwoCalls) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ResetDiskByteUsageWorks) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, CustomEnvVarIsObeyed) {
+ Initialize();
+ setenv("TFDBG_DISK_BYTES_LIMIT", "1024", 1);
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1000L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(23L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1023L));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc
index 2f5aaf93fa..79798f9392 100644
--- a/tensorflow/core/debug/debugger_state_impl.cc
+++ b/tensorflow/core/debug/debugger_state_impl.cc
@@ -27,6 +27,9 @@ DebuggerState::DebuggerState(const DebugOptions& debug_options) {
debug_urls_.insert(url);
}
}
+ if (debug_options.reset_disk_byte_usage()) {
+ DebugFileIO::resetDiskByteUsage();
+ }
}
DebuggerState::~DebuggerState() {
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index abd07e37b7..8e9eec1ed9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
- c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
+ c->req.set_collective_graph_key(client_graph()->collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
- if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
- h = Hash64Combine(opts.collective_graph_key, h);
- }
-
return h;
}
@@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
- if (rcg->build_graph_options().collective_graph_key !=
+ if (rcg->client_graph()->collective_graph_key !=
BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId(
- rcg->build_graph_options().collective_graph_key, step_id);
+ rcg->client_graph()->collective_graph_key, step_id);
}
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = NewStepId(bgopts.collective_graph_key);
+ uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- const uint64 step_id =
- NewStepId(rcg->build_graph_options().collective_graph_key);
+ const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
index bae4ec794c..4c34297990 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.h
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -87,6 +87,9 @@ class TensorResponse {
// modified.
const RecvTensorResponse& metadata() const { return meta_; }
+ // Return pointer to the device hosting the tensor.
+ DeviceBase* device() const { return device_; }
+
private:
bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
TensorProto* tensor_meta);
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 21c6940b62..20a07d86a2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -432,9 +432,9 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
- TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
- &batch_size_dim, &input_spatial_dims,
- &input_depth_dim, c));
+ TF_RETURN_IF_ERROR(DimensionsFromShape(
+ conv_input_shape, data_format, &batch_size_dim,
+ absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
DimensionHandle output_depth_dim = c->Dim(
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index b0b27ce94f..9ffd8e1ee0 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -179,6 +179,13 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
return Status::OK();
}
+void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Placeholder",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
+}
+
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
Node** output) {
*output = ops::SourceOp(
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e06ca68bca..04865a1d4f 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -110,10 +110,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- // Adds a Const node with Tensor value to the Graph.
+ // Adds a `Const` node for the given tensor value to the graph.
+ //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
- // non-null if the method returns with an OK status.
- // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
@@ -122,6 +123,20 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
+ // Adds a `Placeholder` node for the given tensor value to the graph.
+ //
+ // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
+ Status AddPlaceholder(const Tensor& val, Node** output) {
+ AddPlaceholderInternal(val, output);
+ if (*output == nullptr) {
+ return errors::Internal(
+ "AddPlaceholder: Failed to build Placeholder op.");
+ }
+ return Status::OK();
+ }
+
Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
@@ -168,6 +183,7 @@ class GraphDefBuilderWrapper {
}
private:
+ void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -334,7 +350,8 @@ class SerializationContext {
public:
struct Params {
bool allow_stateful_functions = false;
- const FunctionLibraryDefinition* flib_def; // Not owned.
+ const FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+ std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
@@ -343,6 +360,10 @@ class SerializationContext {
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+ std::vector<std::pair<string, Tensor>>* input_list() {
+ return params_.input_list;
+ }
+
private:
Params params_;
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 6e38256ba8..46b169dddc 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -219,6 +219,62 @@ FunctionDef InvalidControlFlow() {
{{"o", "add:z"}});
}
+FunctionDef LessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "LessThanOrEqualToN",
+ // Args
+ {"x: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XPlusOneXTimesY() {
+ const Tensor kOne = test::AsScalar<int64>(1);
+ return FDH::Define(
+ // Name
+ "XPlusOneXTimesY",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"s: T", "t: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
+ {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
+ {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
+}
+
+FunctionDef XYXLessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "XYXLessThanOrEqualToN",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
+ });
+}
+
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index af08d296b2..6d6476b936 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -87,6 +87,15 @@ FunctionDef Swap();
// Contains malformed control flow which can't be run by the executor.
FunctionDef InvalidControlFlow();
+// x:T -> x <= N.
+FunctionDef LessThanOrEqualToN(int64 N);
+
+// x:T, y:T -> x+1, x*y
+FunctionDef XPlusOneXTimesY();
+
+// x:T, y:T -> x <= N
+FunctionDef XYXLessThanOrEqualToN(int64 N);
+
void FunctionTestSchedClosure(std::function<void()> fn);
} // end namespace function
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 333bf761b0..bab1df87a4 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -41,7 +41,7 @@ namespace tensorflow {
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+static const MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 7e501c1717..2e644fe987 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -1043,6 +1043,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -1336,6 +1337,7 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -3214,6 +3216,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 70ad9f9a9b..a24004dc16 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -110,12 +110,13 @@ cc_library(
],
)
-tf_cuda_cc_test(
+tf_cc_test(
name = "constant_folding_test",
srcs = ["constant_folding_test.cc"],
- tags = ["requires-gpu-sm35"],
+ shard_count = 5,
deps = [
":constant_folding",
+ ":dependency_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:all_kernels",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fb2fe6883..4fed88d536 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..d457eb6d21 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -61,7 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fold_multiply_into_conv = true;
bool fold_transpose_into_matmul = true;
bool hoist_common_factor_out_of_aggregation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool minimize_broadcasts = true;
bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 685b5379af..bfccc0affd 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3224,6 +3224,30 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"sqrt", "final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // Should be a NoOp since we are not allowed to change the output of fetch
+ // nodes.
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 815bd23307..99737a71eb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -136,6 +136,27 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input;
}
+bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
+ int* axis) {
+ if (node->op() != "ConcatV2" ||
+ properties.GetInputProperties(node->name()).empty()) {
+ return false;
+ }
+ const auto& axis_input = properties.GetInputProperties(node->name()).back();
+ if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
+ return false;
+ }
+
+ Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
+ if (!axis_tensor.FromProto(axis_input.value())) {
+ return false;
+ }
+ *axis = axis_input.dtype() == DT_INT64
+ ? static_cast<int>(axis_tensor.scalar<int64>()())
+ : axis_tensor.scalar<int32>()();
+ return true;
+}
+
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -852,19 +873,7 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
}
return dtype;
}
-bool IsValidConstShapeForNCHW(const TensorShapeProto& shape) {
- if (shape.dim_size() != 4) {
- return false;
- }
- int num_dim_larger_than_one = 0;
- for (const auto& dim : shape.dim()) {
- if (dim.size() > 1) ++num_dim_larger_than_one;
- }
- return num_dim_larger_than_one <= 1;
-}
-const string& GetShape(const NodeDef& node) {
- return node.attr().at("data_format").s();
-}
+
} // namespace
// static
@@ -1711,7 +1720,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
- if (MulConvPushDown(*properties, optimized_graph, node)) {
+ if (MulConvPushDown(node, *properties)) {
graph_modified_ = true;
return Status::OK();
}
@@ -1731,6 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
+ if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
return Status::OK();
}
@@ -2553,9 +2567,8 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
return false;
}
-bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
- GraphDef* optimized_graph,
- NodeDef* node) {
+bool ConstantFolding::MulConvPushDown(NodeDef* node,
+ const GraphProperties& properties) {
// Push down multiplication on ConvND.
// * ConvND
// / \ / \
@@ -2631,14 +2644,12 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
}
const auto& const_shape = const_props[0].shape();
- if (GetShape(*conv_node) == "NHWC") {
- TensorShapeProto new_filter_shape;
- if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
- return false;
- }
- if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
- return false;
- }
+ TensorShapeProto new_filter_shape;
+ if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
+ return false;
+ }
+ if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
+ return false;
}
string mul_new_name =
@@ -2672,69 +2683,6 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
}
node_map_->AddNode(mul_new_name, node);
- if (GetShape(*conv_node) == "NCHW") {
- if (const_node->attr().at("value").tensor().tensor_shape().dim_size() <=
- 1) {
- // Broadcast should work for scalar or 1D. No need to reshape.
- return true;
- }
- if (!IsValidConstShapeForNCHW(
- const_node->attr().at("value").tensor().tensor_shape())) {
- return false;
- }
- // Adds Const node for Reshape.
- auto* shape_const_node = optimized_graph->add_node();
- const string shape_const_node_name =
- OptimizedNodeName(*const_node, "_new_shape");
- shape_const_node->set_name(shape_const_node_name);
- shape_const_node->set_op("Const");
- shape_const_node->set_device(const_node->device());
- (*shape_const_node->mutable_attr())["dtype"].set_type(DT_INT32);
- Tensor t(DT_INT32, {4});
- t.flat<int32>()(0) = 1;
- t.flat<int32>()(1) = 1;
- t.flat<int32>()(2) = 1;
- t.flat<int32>()(3) = const_node->attr()
- .at("value")
- .tensor()
- .tensor_shape()
- .dim(1) // IsValidConstShapeForNCHW guarantees
- // dim 1 is the dim to reshape
- .size();
- t.AsProtoTensorContent(
- (*shape_const_node->mutable_attr())["value"].mutable_tensor());
- node_map_->AddNode(shape_const_node_name, shape_const_node);
-
- // Adds Reshape node.
- auto* reshape_node = optimized_graph->add_node();
- const string reshape_node_name =
- OptimizedNodeName(*const_node, "_reshape");
- reshape_node->set_op("Reshape");
- reshape_node->set_name(reshape_node_name);
- reshape_node->set_device(const_node->device());
- (*reshape_node->mutable_attr())["T"].set_type(
- const_node->attr().at("dtype").type());
- (*reshape_node->mutable_attr())["Tshape"].set_type(DT_INT32);
- node_map_->AddNode(reshape_node_name, reshape_node);
-
- // const_node -> reshape_node
- node_map_->RemoveOutput(const_node->name(), node->name());
- *reshape_node->add_input() = const_node->name();
- node_map_->AddOutput(const_node->name(), reshape_node_name);
-
- // shape_const_node -> reshape_node
- *reshape_node->add_input() = shape_const_node_name;
- node_map_->AddOutput(shape_const_node_name, reshape_node_name);
-
- // reshape_node -> node (Mul)
- node_map_->AddOutput(reshape_node_name, node->name());
- if (left_child_is_constant) {
- node->set_input(0, reshape_node_name);
- } else {
- node->set_input(1, reshape_node_name);
- }
- }
-
return true;
}
return false;
@@ -2988,6 +2936,55 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
return false;
}
+bool ConstantFolding::MergeConcat(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node) {
+ // We only optimize for ConcatV2.
+ int axis;
+ if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+ nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
+ node_map_->GetOutputs(node->name()).size() != 1) {
+ return false;
+ }
+
+ NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
+ int parent_axis;
+ if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+ return false;
+ }
+
+ const int index = NumNonControlInputs(*node) - 1;
+ auto inputs = parent->input();
+ parent->clear_input();
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (IsSameInput(inputs.Get(i), node->name())) {
+ for (int j = 0; j < node->input_size(); ++j) {
+ if (j < index) {
+ // Input tensors (non axis), add to input list of parent.
+ parent->add_input(node->input(j));
+ node_map_->RemoveOutput(node->input(j), node->name());
+ node_map_->AddOutput(node->input(j), parent->name());
+ }
+ // Skip j == index, which means axis tensor.
+ if (j > index) {
+ // Control Dependencies, push back to inputs so they can be forwarded
+ // to parent.
+ *inputs.Add() = node->input(j);
+ }
+ }
+ } else {
+ parent->add_input(inputs.Get(i));
+ }
+ }
+ node->clear_input();
+ node->set_op("NoOp");
+ node->clear_attr();
+ node_map_->RemoveNode(node->name());
+ (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+
+ return true;
+}
+
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 051dfb681e..8593b3e0b8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -125,8 +125,7 @@ class ConstantFolding : public GraphOptimizer {
// Aggregate constants present around a conv operator. Returns true if the
// transformation was applied successfully.
- bool MulConvPushDown(const GraphProperties& properties,
- GraphDef* optimized_graph, NodeDef* node);
+ bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);
// Strength reduces floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)).
@@ -210,6 +209,10 @@ class ConstantFolding : public GraphOptimizer {
// Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
+
+ bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 0683572dcc..2a19b3f95a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -240,7 +240,7 @@ TEST_F(ConstantFoldingTest, AddTree) {
}
}
-TEST_F(ConstantFoldingTest, ConvPushDownTestNHWC) {
+TEST_F(ConstantFoldingTest, ConvPushDownTest) {
// Tests if the following rewrite is performed:
//
// * Conv2D
@@ -2030,6 +2030,130 @@ TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
CompareGraphs(want, got);
}
+TEST_F(ConstantFoldingTest, MergeConcat) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
+ &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
+ Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis1", "Const", {}, {}, &want);
+ AddNode("axis2", "Const", {}, {}, &want);
+ AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
+ AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
@@ -3080,110 +3204,6 @@ TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
-#if GOOGLE_CUDA
-TEST_F(ConstantFoldingTest, ConvPushDownTestNCHW) {
- // Tests if the following rewrite is performed:
- //
- // * Conv2D
- // / \ / \
- // c Conv2D --> x (c * filter)
- // / \
- // x filter
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- int input_channel = 1;
- int output_channel = 2;
- int filter_size = 1;
-
- TensorShape filter_shape(
- {filter_size, filter_size, input_channel, output_channel});
-
- // Filter shape: [1, 1, 1, 2]
- // Filter for output channel 0 = {2.f}
- // Filter for output channel 1 = {-2.f}
- // clang-format off
- Output filter =
- ops::Const(s.WithOpName("filter"), {
- {
- {{2.f, -2.f}}
- }
- });
- // clang-format on
-
- int batch_size = 1;
- int matrix_size = 3;
- // input shape: [1,1,3,3]
- TensorShape input_shape(
- {batch_size, input_channel, matrix_size, matrix_size});
- Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
- ops::Placeholder::Shape(input_shape));
-
- Output conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
- "VALID", ops::Conv2D::DataFormat("NCHW"));
- Output c = ops::Const(s.WithOpName("c"), 2.0f, /* shape */ {1, 2, 1, 1});
- Output mul = ops::Mul(s.WithOpName("mul"), c, conv);
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- ConstantFolding fold(nullptr);
- GraphDef output;
- Status status = fold.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- // Here only op/IO are checked. The values are verified by EvaluateNodes
- // below.
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "mul") {
- ++found;
- EXPECT_EQ("Conv2D", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("conv/merged_input", node.input(1));
- } else if (node.name() == "conv/merged_input") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(0, node.input_size());
- }
- }
- EXPECT_EQ(2, found);
-
- // Check that const folded multiplication node has the expected value.
- std::vector<string> fetch = {"mul"};
- // Input shape (NCHW) is [1,1,3,3], filter is [1,1,1,2] output shape should be
- // (NCHW) [1,2,3,3]
- ::tensorflow::Input::Initializer x{
- {
- {
- {1.f, 2.f, 3.f}, // H = 0
- {4.f, 5.f, 6.f}, // H = 1
- {7.f, 8.f, 9.f} // H = 2
- } // C = 0
- } // N = 0
- };
-
- // |1,2,3|
- // conv( |4,5,6|, // input
- // |7,8,9|
- // [[[2,-2]]]) // filter
- // * [1,2,1,1] // mul by const
- // =
- // [
- // |4, 8, 12|
- // |16,20,24| ==> output channel 0
- // |28,32,36|
- //
- // | -4, -8,-12|
- // |-16,-20,-24| ==> output channel 1
- // |-28,-32,-36|
- // ]
- auto actual = EvaluateNodes(output, fetch, {{"x", x.tensor}});
- auto expected = EvaluateNodes(item.graph, fetch, {{"x", x.tensor}});
- test::ExpectTensorEqual<float>(expected[0], actual[0]);
-}
-#endif
-
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 91794cefe5..c775a26914 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -1071,11 +1071,13 @@ static bool IdentifySwappingCandidates(
// ensure that swapping the tensor back in won't recreate the memory
// bottleneck. Last but not least, we want the tensor to have as few
// remaining uses as possible.
+ //
+ // Note that we must perform the arithmetic inexactly as "double", since
+ // the values do not fit into any integral type.
mem_info.fitness =
- MathUtil::IPow((earliest_use - peak_time).count(), 2);
- mem_info.fitness /= MathUtil::IPow(mem_info.uses_left.size(), 2);
- mem_info.fitness +=
- MathUtil::IPow((allocation_time - peak_time).count(), 2);
+ MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
+ MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
+ MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info);
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 633fe9ab77..25063ac823 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2296,6 +2296,31 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "eigen_benchmark",
+ testonly = 1,
+ hdrs = [
+ "eigen_benchmark.h",
+ ":eigen_helpers",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
+ name = "eigen_benchmark_cpu_test",
+ srcs = ["eigen_benchmark_cpu_test.cc"],
+ deps = [
+ ":eigen_benchmark",
+ ":eigen_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_tests(
name = "basic_ops_benchmark_test",
size = "small",
@@ -4196,6 +4221,7 @@ cc_library(
"hinge-loss.h",
"logistic-loss.h",
"loss.h",
+ "poisson-loss.h",
"smooth-hinge-loss.h",
"squared-loss.h",
],
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 654d99301a..663bff3657 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -89,9 +89,9 @@ class BaseCandidateSamplerOp : public OpKernel {
// Pick sampled candidates.
auto local_gen = generator_.ReserveSamples32(samples32);
random::SimplePhilox random(&local_gen);
- sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate,
- &sampled_expected_count,
- true_candidate, &true_expected_count);
+ sampler_->SampleBatchGetExpectedCount(&random, unique_, sampled_candidate,
+ sampled_expected_count,
+ true_candidate, true_expected_count);
if (sampler_->NeedsUpdates()) {
sampler_->Update(true_candidate);
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 5de41bac72..e0da91125b 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,14 +132,19 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // Allocate the output tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c,
- c->forward_input_or_allocate_output(
- {0}, 0, c->input(0).shape(), &output),
- done);
-
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
@@ -183,16 +188,23 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->forward_input_or_allocate_output({0}, 0, shape_, &output),
+ done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
c, shape_.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
- // Allocate the output Tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
@@ -239,10 +251,16 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // No input, so must allocate output.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc
index 2c5538534c..dc064eec5f 100644
--- a/tensorflow/core/kernels/cwise_op_zeta.cc
+++ b/tensorflow/core/kernels/cwise_op_zeta.cc
@@ -18,4 +18,9 @@ limitations under the License.
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double);
REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
+
+#if GOOGLE_CUDA
+REGISTER2(BinaryOp, GPU, "Zeta", functor::zeta, float, double);
+REGISTER2(BinaryOp, GPU, "Polygamma", functor::polygamma, float, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 8d867455e7..e7b3d0c92f 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -481,8 +481,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -505,8 +504,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 831e7252da..6263dc3cf8 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -92,8 +92,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
+ std::vector<std::pair<string, Tensor>> input_list;
params.allow_stateful_functions = true;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ params.input_list = &input_list;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
@@ -118,7 +120,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(
- graph_runner.Run(&graph, lib_, {}, {output_node}, &outputs));
+ graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 52753a3ccd..8957f5d997 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -242,202 +242,6 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
-class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit FeatureStatsDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- string tag;
- OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
- OP_REQUIRES(ctx, input->output_dtypes()[0] == DT_STRING,
- errors::InvalidArgument("FeatureStatsDataset only supports "
- "input with a single `tf.string` "
- "component."));
- *output = new Dataset(ctx, input, std::move(tag));
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : DatasetBase(DatasetContext(ctx)),
- input_(input),
- tag_(std::move(tag)) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::FeatureStatsDataset")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
- return "FeatureStatsDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_node;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
- Node* tag_node;
- TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- tf_shared_lock l(mu_);
- Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- auto stats_aggregator = ctx->stats_aggregator();
- if (stats_aggregator && s.ok() && !*end_of_sequence) {
- for (const Tensor& t : *out_tensors) {
- auto record_t = t.flat<string>();
- Example example;
- // TODO(b/111553342): redundant parsing here, potential solutions
- // to improve performance is to a) have a potential
- // ParseExampleDataset and collect stats from there and b) make
- // changes to parse_example() where it returns stats as well.
- for (int i = 0; i < record_t.size(); ++i) {
- if (example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("examples_count", "trainer",
- 1);
- AddStatsFeatures(example, stats_aggregator);
- } else {
- SequenceExample sequence_example;
- if (sequence_example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("sequence_examples_count",
- "trainer", 1);
- AddStatsFeatures(sequence_example, stats_aggregator);
- }
- }
- }
- }
- }
- return s;
- }
-
- int AddStatsFeatureValues(const Feature& feature) {
- int feature_values_list_size = 0;
- switch (feature.kind_case()) {
- case Feature::kBytesList: {
- feature_values_list_size = feature.bytes_list().value().size();
- break;
- }
- case Feature::kFloatList: {
- feature_values_list_size = feature.float_list().value().size();
- break;
- }
- case Feature::kInt64List: {
- feature_values_list_size = feature.int64_list().value().size();
- break;
- }
- case Feature::KIND_NOT_SET:
- break;
- }
- return feature_values_list_size;
- }
-
- void AddStatsFeatures(
- const Example& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(example.features().feature().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.features().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- void AddStatsFeatures(
- const SequenceExample& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(
- example.context().feature().size() +
- example.feature_lists().feature_list().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.context().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
-
- for (const auto& feature_list :
- example.feature_lists().feature_list()) {
- stats_aggregator->IncrementCounter("feature_lists_count", "trainer",
- 1);
- for (const auto& feature : feature_list.second.feature()) {
- feature_values_list_size_sum += AddStatsFeatureValues(feature);
- }
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- };
-
- const DatasetBase* const input_;
- const string tag_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("FeatureStatsDataset").Device(DEVICE_CPU),
- FeatureStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index fc21c3235a..1192fafc4c 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
@@ -28,8 +29,6 @@ class TensorDatasetOp : public DatasetOpKernel {
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
@@ -74,7 +73,13 @@ class TensorDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 5b051e0e08..dc32cd23e5 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/util/batch_util.h"
@@ -30,8 +31,6 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
: DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
std::vector<Tensor> components;
@@ -93,7 +92,13 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index f7c68e8d47..33ed5522d0 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -177,8 +177,10 @@ class BaseDebugOp : public OpKernel {
// Publish a tensor to all debug URLs of the debug op.
// Log an error if the publishing failed.
- void PublishTensor(const Tensor& tensor) {
- if (!debug_urls_.empty()) {
+ Status PublishTensor(const Tensor& tensor) {
+ if (debug_urls_.empty()) {
+ return Status::OK();
+ } else {
Status status = DebugIO::PublishDebugTensor(*debug_watch_key_, tensor,
Env::Default()->NowMicros(),
debug_urls_, gated_grpc_);
@@ -189,6 +191,7 @@ class BaseDebugOp : public OpKernel {
<< str_util::Join(debug_urls_, ", ")
<< ", due to: " << status.error_message();
}
+ return status;
}
}
@@ -213,7 +216,7 @@ class DebugIdentityOp : public BaseDebugOp {
return;
}
- PublishTensor(context->input(0));
+ OP_REQUIRES_OK(context, PublishTensor(context->input(0)));
context->set_output(0, context->input(0));
}
};
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
new file mode 100644
index 0000000000..46ad38fb77
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -0,0 +1,298 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using ::tensorflow::TTypes;
+
+template <typename Scalar, typename Device>
+class SpatialConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 4>::ConstTensor;
+ using Filter = TTypes<float, 4>::ConstTensor;
+ using Output = TTypes<float, 4>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
+
+ SpatialConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::SpatialConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void SpatialConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, input, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void SpatialConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using FilterGrad = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, input_dims);
+ FilterGrad filter_grad(filter_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_grad);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+ Eigen::Index input_planes = input_dims[3];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, input, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using FilterGrad = TTypes<float, 5>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+ Eigen::Index filter_planes = filter_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterGrad filter_grad(filter_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward, filter_planes, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_grad);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
new file mode 100644
index 0000000000..2a8308ef9a
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -0,0 +1,402 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENTE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONT OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_benchmark.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+#define CREATE_THREAD_POOL(threads) \
+ Eigen::ThreadPool tp(threads); \
+ Eigen::ThreadPoolDevice device(&tp, threads)
+
+// -------------------------------------------------------------------------- //
+// Spatial Convolutions //
+// -------------------------------------------------------------------------- //
+
+void SpatialConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolution(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto filter_size = filter_dims.TotalSize();
+ auto flops = filter_size * (input_batches * input_height * input_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro arguments names: --------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
+ BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW
+
+#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK( \
+ BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW))
+
+#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdInput(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdInput(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdKernel(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdKernel(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(16, N, H, W, C, FC, FH, FW, LABEL);
+
+// ImageNet Forward Convolutions -------------------------------------------- //
+
+BM_SpatialConvolutions(32, // batch size
+ 56, 56, 64, // input: height, width, depth
+ 192, 3, 3, // filter: count, height, width
+ "conv2_00");
+
+BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutions(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutions(32, 7, 7, 48, 128, 5, 5, "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_SpatialConvolutions(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
+BM_SpatialConvolutions(128, 64, 64, 64, 128, 9, 9, "convnet-layer2");
+BM_SpatialConvolutions(128, 32, 32, 128, 128, 9, 9, "convnet-layer3");
+BM_SpatialConvolutions(128, 16, 16, 128, 128, 7, 7, "convnet-layer4");
+BM_SpatialConvolutions(128, 13, 13, 384, 384, 3, 3, "convnet-layer5");
+
+// ImageNet BackwardInput Convolutions -------------------------------------- //
+
+BM_SpatialConvolutionsBwdInput(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// ImageNet BackwardKernel Convolutions ------------------------------------- //
+
+BM_SpatialConvolutionsBwdKernel(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// -------------------------------------------------------------------------- //
+// Cuboid Convolutions //
+// -------------------------------------------------------------------------- //
+
+void CuboidConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_planes, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width,
+ int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolution(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto output_size = input_dims.TotalSize();
+ auto flops = output_size *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto filter_size = filter_dims.TotalSize();
+ auto flops =
+ filter_size * (input_batches * input_height * input_width * input_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro arguments names: --------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// P: panes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter panes
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_CONCAT(BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##C, \
+ _f_##FC##_##FH##_##FW##_##FP)
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
+ LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
+ FC, FH, FW, FP)(int iters) { \
+ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdKernel(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdKernel(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops).
+
+BM_CuboidConvolutions(8, // batch size
+ 25, 25, 25, 4, // input: height, width, panes, depth
+ 16, 5, 5, 5, // filter: count, height, width, panes
+ "conv3d");
+
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
+
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index 83cd0e9b47..528b3c6bf0 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -264,9 +264,168 @@ class ParseSingleExampleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
ParseSingleExampleOp);
-class SingleSequenceExampleParserOp : public OpKernel {
+class ParseSequenceExampleOp : public OpKernel {
public:
- explicit SingleSequenceExampleParserOp(OpKernelConstruction* ctx)
+ explicit ParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* debug_name;
+ const Tensor* serialized;
+ OpInputList context_dense_defaults;
+
+ OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
+ OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
+ OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
+ &context_dense_defaults));
+
+ bool has_debug_name = (debug_name->NumElements() > 0);
+ if (has_debug_name) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(debug_name->shape()),
+ errors::InvalidArgument(
+ "Expected debug_name to be a vector, got shape: ",
+ debug_name->shape().DebugString()));
+ }
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
+ errors::InvalidArgument(
+ "Expected serialized to be a vector, got shape: ",
+ serialized->shape().DebugString()));
+
+ OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
+ errors::InvalidArgument("Expected len(context_dense_defaults) "
+ "== len(context_dense_keys) but got: ",
+ context_dense_defaults.size(), " vs. ",
+ attrs_.num_context_dense));
+
+ std::vector<bool> required(attrs_.num_context_dense);
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ const Tensor& def_value = context_dense_defaults[d];
+ required[d] = (def_value.NumElements() == 0); // No default provided.
+
+ if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
+ errors::InvalidArgument(
+ "default_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " != context_dense_shapes[", d,
+ "] == ", attrs_.context_dense_shapes[d].DebugString()));
+ OP_REQUIRES(
+ ctx, def_value.dtype() == attrs_.context_dense_types[d],
+ errors::InvalidArgument(
+ "context_dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != context_dense_types[",
+ d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
+ }
+ }
+
+ example::Result context_result, feature_list_result;
+ std::vector<Tensor> dense_feature_lengths;
+
+ example::FastParseExampleConfig context_config;
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_config.dense.push_back(
+ {attrs_.context_dense_keys[d], attrs_.context_dense_types[d],
+ attrs_.context_dense_shapes[d], context_dense_defaults[d],
+ false /* attrs_.context_variable_length[d] */,
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_config.sparse.push_back(
+ {attrs_.context_sparse_keys[d], attrs_.context_sparse_types[d]});
+ }
+ example::FastParseExampleConfig feature_list_config;
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ DataType dtype = attrs_.feature_list_dense_types[d];
+ Tensor default_value = Tensor(dtype, TensorShape({}));
+ feature_list_config.dense.push_back(
+ {attrs_.feature_list_dense_keys[d], dtype,
+ attrs_.feature_list_dense_shapes[d], default_value,
+ (attrs_.feature_list_dense_missing_assumed_empty.count(
+ attrs_.feature_list_dense_keys[d]) > 0),
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_config.sparse.push_back(
+ {attrs_.feature_list_sparse_keys[d],
+ attrs_.feature_list_sparse_types[d]});
+ }
+
+ auto serialized_t = serialized->flat<string>();
+ auto debug_name_t = debug_name->flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<string> names_slice(debug_name_t.data(),
+ debug_name_t.size());
+
+ OP_REQUIRES_OK(
+ ctx,
+ FastParseSequenceExample(
+ context_config, feature_list_config, slice, names_slice,
+ ctx->device()->tensorflow_cpu_worker_threads()->workers,
+ &context_result, &feature_list_result, &dense_feature_lengths));
+
+ OpOutputList context_sparse_indices;
+ OpOutputList context_sparse_values;
+ OpOutputList context_sparse_shapes;
+ OpOutputList context_dense_values;
+ OpOutputList feature_list_sparse_indices;
+ OpOutputList feature_list_sparse_values;
+ OpOutputList feature_list_sparse_shapes;
+ OpOutputList feature_list_dense_values;
+ OpOutputList feature_list_dense_lengths;
+
+ OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
+ &context_sparse_indices));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_dense_values", &context_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
+ &context_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
+ &feature_list_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
+ &feature_list_sparse_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
+ &feature_list_sparse_shapes));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
+ &feature_list_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_lengths",
+ &feature_list_dense_lengths));
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_dense_values.set(d, context_result.dense_values[d]);
+ }
+ TensorShape lengths_shape;
+ lengths_shape.AddDim(serialized_t.size());
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
+ feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_sparse_indices.set(d, context_result.sparse_indices[d]);
+ context_sparse_values.set(d, context_result.sparse_values[d]);
+ context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
+ feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
+ feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
+ }
+ }
+
+ protected:
+ ParseSequenceExampleAttrs attrs_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
+ ParseSequenceExampleOp);
+
+class ParseSingleSequenceExampleOp : public OpKernel {
+ public:
+ explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
}
@@ -658,7 +817,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
- SingleSequenceExampleParserOp);
+ ParseSingleSequenceExampleOp);
#ifndef IS_MOBILE_PLATFORM
// when using lite protos on mobile, decoding JSON is not available.
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index ad0112e6cb..66ae7f0894 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
#endif
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
+
+#ifdef INTEL_MKL
+// Eigen implementation below is not highly performant. gather_nd_generator
+// does not seem to be called in parallel, leading to very poor performance.
+// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
+// needs to go through redundant operations like 'reshape', 'broadcast' and
+// 'sum'. OpenMP loop below essentially does same thing as Eigen code, but
+// is considerably more efficient.
+#pragma omp parallel for
+ for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
+ const Eigen::array<Eigen::DenseIndex, 1> loc = i;
+ gather_nd_generator(loc);
+ }
+#else
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
+#endif
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 84fa63fc00..bca1cff41c 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -588,7 +588,11 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListStack<CPUDevice, T>)
+ TensorListStack<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListGather<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU);
REGISTER_TENSOR_LIST_STACK_CPU(quint8);
@@ -604,7 +608,11 @@ REGISTER_TENSOR_LIST_STACK_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListFromTensor<CPUDevice, T>)
+ TensorListFromTensor<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListScatter<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8);
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 0ea9362cbe..c591226b76 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -40,7 +40,12 @@ typedef Eigen::GpuDevice GPUDevice;
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU), \
- TensorListStack<GPUDevice, T>)
+ TensorListStack<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("indices"), \
+ TensorListGather<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_STACK_GPU);
REGISTER_TENSOR_LIST_STACK_GPU(bfloat16);
@@ -71,7 +76,13 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU) \
.HostMemory("element_shape"), \
- TensorListFromTensor<GPUDevice, T>)
+ TensorListFromTensor<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("element_shape") \
+ .HostMemory("indices"), \
+ TensorListScatter<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_GPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bfloat16);
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index b3f74c060b..066a1d603b 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -134,6 +134,74 @@ class TensorListStack : public OpKernel {
};
template <typename Device, typename T>
+class TensorListGather : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+ explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "Input handle is not a list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument("Invalid data types; op elements ",
+ DataTypeString(element_dtype_),
+ " but list elements ",
+ DataTypeString(l->element_dtype)));
+ OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+ errors::InvalidArgument("Tried to stack elements from a list "
+ "with non-fully-defined shape: ",
+ l->element_shape.DebugString()));
+ Tensor indices = c->input(1);
+ TensorShape resulting_shape;
+ resulting_shape.AddDim(indices.NumElements());
+ for (TensorShapeDim s : l->element_shape) {
+ resulting_shape.AddDim(s.size);
+ }
+ Tensor* output;
+ OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+ if (output->NumElements() == 0) {
+ return;
+ }
+
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(l->tensors.size());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(
+ c, i < l->tensors.size(),
+ errors::InvalidArgument("Index ", i, " out o range; list only has ",
+ l->tensors.size(), " elements."));
+ const Tensor& t = l->tensors[i];
+ OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()),
+ errors::InvalidArgument(
+ "Tensor with invalid shape in list. List element shape: ",
+ l->element_shape.DebugString(),
+ " and tensor shape: ", t.shape().DebugString()));
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ t.shaped<T, 2>({1, t.NumElements()})));
+ }
+ auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
+
+#if GOOGLE_CUDA
+ if (std::is_same<Device, Eigen::GpuDevice>::value) {
+ ConcatGPU<T>(c, inputs_flat, output, &output_flat);
+ return;
+ }
+#endif // GOOGLE_CUDA
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
public:
TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
@@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel {
}
};
+template <typename Device, typename T>
+class TensorListScatter : public OpKernel {
+ public:
+ TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ Tensor* output_tensor;
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
+ Tensor indices = c->input(1);
+ PartialTensorShape element_shape;
+ OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
+ TensorList output_list;
+ const Tensor& t = c->input(0);
+ output_list.element_dtype = t.dtype();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument(
+ "Tensor must be at least a vector, but saw shape: ",
+ t.shape().DebugString()));
+ TensorShape output_shape(t.shape());
+ output_shape.RemoveDim(0);
+ OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
+ errors::InvalidArgument(
+ "Specified a list with shape ", element_shape.DebugString(),
+ " from a tensor with shape ", output_shape.DebugString()));
+ output_list.element_shape = element_shape;
+ output_list.tensors.reserve(indices.NumElements());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(c, i < t.shape().dim_size(0),
+ errors::InvalidArgument("Trying to scatter index ", i,
+ " from tensor with ",
+ t.shape().dim_size(0), " rows."));
+ Tensor tmp = t.Slice(i, i + 1);
+ TensorShape tmp_shape = tmp.shape();
+ tmp_shape.RemoveDim(0);
+ OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
+ errors::Unknown("Unexpected shape error."));
+ // TODO(apassos) maybe not always align; but weird compiler bugs seem to
+ // prevent this.
+ Tensor aligned;
+ OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
+ // TODO(apassos) do all slices in a single kernel invocation instead of
+ // many small ondes.
+ aligned.flat<T>().device(c->eigen_device<Device>()) =
+ tmp.unaligned_flat<T>();
+ output_list.tensors.push_back(aligned);
+ }
+ output_tensor->scalar<Variant>()() = std::move(output_list);
+ }
+};
+
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
const TensorList& b, TensorList* out) {
diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h
index b43902e0b9..9198a98e47 100644
--- a/tensorflow/core/kernels/logistic-loss.h
+++ b/tensorflow/core/kernels/logistic-loss.h
@@ -86,7 +86,7 @@ class LogisticLossUpdater : public DualLossUpdater {
} else {
inverse_exp_term = 1 / (1 + exp(label * wx));
}
- return inverse_exp_term * label * example_weight;
+ return -inverse_exp_term * label * example_weight;
}
// The smoothness constant is 4 since the derivative of logistic loss, which
diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc
index 460d65c5c2..9209ed2ab7 100644
--- a/tensorflow/core/kernels/loss_test.cc
+++ b/tensorflow/core/kernels/loss_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,6 +31,24 @@ namespace {
// TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton
// modification detailed in readme.md
+// This test checks that the dual value after update is optimal.
+// At the optimum the dual value should be the opposite of the primal gradient.
+// This does not hold at a point where the primal is not differentiable.
+void TestComputeUpdatedDual(const DualLossUpdater &loss_updater,
+ const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) {
+ double new_dual = loss_updater.ComputeUpdatedDual(
+ num_loss_partitions, label, example_weight, current_dual, wx,
+ weighted_example_norm);
+ // The primal gradient needs to be computed after the weight update.
+ double new_wx = wx + (new_dual - current_dual) * num_loss_partitions *
+ weighted_example_norm * example_weight;
+ EXPECT_NEAR(new_dual, -loss_updater.PrimalLossDerivative(new_wx, label, 1.0),
+ 1e-5);
+}
+
TEST(LogisticLoss, ComputePrimalLoss) {
LogisticLossUpdater loss_updater;
EXPECT_NEAR(0.693147,
@@ -65,19 +84,12 @@ TEST(LogisticLoss, ComputeDualLoss) {
TEST(LogisticLoss, ComputeUpdatedDual) {
LogisticLossUpdater loss_updater;
- EXPECT_NEAR(0.479,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.031,
- loss_updater.ComputeUpdatedDual(
- 2 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, 0.1 /* current_dual */,
- -0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, 0.1 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SquaredLoss, ComputePrimalLoss) {
@@ -126,19 +138,12 @@ TEST(SquaredLoss, ComputeDualLoss) {
TEST(SquaredLoss, ComputeUpdatedDual) {
SquaredLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(HingeLoss, ComputePrimalLoss) {
@@ -207,48 +212,27 @@ TEST(HingeLoss, ConvertLabel) {
TEST(HingeLoss, ComputeUpdatedDual) {
HingeLossUpdater loss_updater;
- // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and
- // weighted_example_norm=100.0, it turns out that the optimal value to update
- // the dual to is 0.507 which is within the permitted range and thus should be
- // the value returned.
+ // For the two tests belows, y*wx=1 after the update which is a
+ // non-differetiable point of the hinge loss and TestComputeUpdatedDual
+ // cannot be used. Check value of the dual variable instead.
EXPECT_NEAR(0.507,
loss_updater.ComputeUpdatedDual(
1 /* num partitions */, 1.0 /* label */,
1.0 /* example weight */, 0.5 /* current_dual */,
0.3 /* wx */, 100.0 /* weighted_example_norm */),
1e-3);
- // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6,
- // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that
- // the optimal value to update the dual to is 0.384 which is within the
- // permitted range and thus should be the value returned.
EXPECT_NEAR(-0.416,
loss_updater.ComputeUpdatedDual(
10 /* num partitions */, -1.0 /* label */,
1.0 /* example weight */, -0.4 /* current_dual */,
0.6 /* wx */, 10.0 /* weighted_example_norm */),
1e-3);
- // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range
- // and hence the closest permitted value (0.0) should be returned instead.
- EXPECT_NEAR(0.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, -0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0]
- // range and hence the closest permitted value (-1.0) should be returned
- // instead.
- EXPECT_NEAR(-1.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, -1.0 /* label */,
- 2.0 /* example weight */, -1.0 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, -0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, -1.0 /* label */,
+ 2.0 /* example weight */, -1.0 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SmoothHingeLoss, ComputePrimalLoss) {
@@ -297,19 +281,75 @@ TEST(SmoothHingeLoss, ComputeDualLoss) {
TEST(SmoothHingeLoss, ComputeUpdatedDual) {
SmoothHingeLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
+}
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
+TEST(PoissonLoss, ComputePrimalLoss) {
+ PoissonLossUpdater loss_updater;
+ EXPECT_NEAR(1.0,
+ loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
1e-3);
+ EXPECT_NEAR(21996.0,
+ loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
+ 1.0);
+ EXPECT_NEAR(0.606,
+ loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(6.64,
+ loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */,
+ 2.0 /* example weight */),
+ 1e-2);
+}
+
+TEST(PoissonLoss, ComputeDualLoss) {
+ PoissonLossUpdater loss_updater;
+ // Dual is undefined.
+ EXPECT_NEAR(
+ std::numeric_limits<double>::max(),
+ loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ 0.0,
+ loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -0.847,
+ loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -2.675,
+ loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+}
+
+TEST(PoissonLoss, ConvertLabel) {
+ PoissonLossUpdater loss_updater;
+ float example_label = -1.0;
+ // Negative label should throw an error.
+ Status status = loss_updater.ConvertLabel(&example_label);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(PoissonLoss, ComputeUpdatedDual) {
+ PoissonLossUpdater loss_updater;
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */,
+ 1.0 /* example weight */, 0.0 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
} // namespace
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index afbfaa83f3..52157ed5fb 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -300,19 +300,24 @@ template <typename T>
class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConvBwdFilterPrimitive<T>* Get(
- const MklConvBwdFilterParams& convBwdFilterDims) {
+ const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
- // look into the pool for reusable primitive
- conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
+ if (do_not_cache) { /* Create new primitive always */
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*> (
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
convBwdFilterDims));
- if (conv_bwd_filter == nullptr) {
- conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
- MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
- convBwdFilterDims, conv_bwd_filter);
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
+ }
}
+
return conv_bwd_filter;
}
@@ -845,8 +850,13 @@ class MklConvCustomBackpropFilterOp
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv_bwd_filter =
- MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims);
+
+ // MKL DNN allocates large buffers when a conv gradient filter primtive is
+ // created. So we don't cache conv backward primitives when the env
+ // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
+ conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims, do_not_cache);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_fitler and diff_bias (w bias)
@@ -938,6 +948,9 @@ class MklConvCustomBackpropFilterOp
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index b5a98301e2..c38c9cc27c 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -174,7 +174,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
}
};
-
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(new memory::desc(
@@ -242,19 +241,23 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConvBwdInputPrimitive<T>* Get(
- const MklConvBwdInputParams& convBwdInputDims) {
+ const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
- // look into the pool for reusable primitive
- conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
- MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
- convBwdInputDims));
-
- if (conv_bwd_input == nullptr) {
+ if (do_not_cache) { /* Always allocate primitive */
conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
- MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
- convBwdInputDims, conv_bwd_input);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+ convBwdInputDims));
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
+ }
}
+
return conv_bwd_input;
}
@@ -708,8 +711,18 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv_bwd_input =
- MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims);
+
+ // We don't cache those primitves if the env variable
+ // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor
+ // includes potentialy large buffers. MKL DNN allocates buffers
+ // in the following cases
+ // 1. Legacy CPU without AVX512/AVX2, or
+ // 2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(fwd_filter_dims, strides));
+ conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+ do_not_cache);
auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
@@ -755,6 +768,11 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
// execute convolution input bwd
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) {
+ delete conv_bwd_input;
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index c6295c7280..9b10c3f3d6 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -271,18 +271,23 @@ class MklConvFwdPrimitive : public MklPrimitive {
template <typename T>
class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims,
+ bool do_not_cache) {
MklConvFwdPrimitive<T>* conv_fwd = nullptr;
- // try to find a suitable one in pool
- conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
- MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
-
- if (conv_fwd == nullptr) {
+ if (do_not_cache) { /* Always create new primitive */
conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
- MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
- conv_fwd);
+ } else {
+ // try to find a suitable one in pool
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
+ }
}
+
return conv_fwd;
}
@@ -894,6 +899,17 @@ class MklConvOp : public OpKernel {
// MKLDNN dilation starts from 0.
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
+ // In some cases, primitve descriptor includes potentialy large buffers,
+ // we don't cache those primitves if the env variable
+ // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+ // in the following cases
+ // 1. Legacy CPU without AVX512/AVX2, or
+ // 2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(filter_dims, strides));
+
// get a conv2d fwd from primitive pool
MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
@@ -902,12 +918,14 @@ class MklConvOp : public OpKernel {
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
}
// allocate output tensors output_tensor and filter_out_tensor
@@ -952,6 +970,9 @@ class MklConvOp : public OpKernel {
} else {
conv_fwd->Execute(src_data, filter_data, dst_data);
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_fwd;
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index b78b763fd6..f4cfc48af5 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -35,6 +35,7 @@ using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
+using mkldnn::memory;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
@@ -867,11 +868,12 @@ class MklReluOpBase : public OpKernel {
eltwise_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
@@ -886,7 +888,8 @@ class MklReluGradOpBase : public OpKernel {
~MklReluGradOpBase() {}
explicit MklReluGradOpBase(OpKernelConstruction* context)
- : OpKernel(context) {}
+ : OpKernel(context) {
+ }
virtual void Compute_Scalar(OpKernelContext* context) = 0;
@@ -942,8 +945,12 @@ class MklReluGradOpBase : public OpKernel {
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
- src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
- diff_dst_tf_data_format);
+
+ src_dims = (src_tensor.dims() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+ diff_dst_tf_data_format)
+ : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
+ diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
} else {
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 8bde966be9..04d8a1bdeb 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -50,6 +50,7 @@ class MklSoftmaxOp : public OpKernel {
// src_tensor now points to the 0-th input of global data struct "context"
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const int input_dims = src_tensor.dims();
// Add: get MklShape
MklDnnShape src_mkl_shape;
@@ -62,7 +63,32 @@ class MklSoftmaxOp : public OpKernel {
: src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
-
+ memory::format layout_type;
+ // In MKL, data format passed to mkl softmax op depends on dimension of the input tensor.
+ // Here "x" data format in MKL is used for 1 dim tensor, "nc" for 2 dim tensor,
+ // "tnc" for 3 dim tensor, "nchw" for 4 dim tensor, and "ncdhw" for 5 dim tensor.
+ // Each of the simbols has the following meaning:
+ // n = batch, c = channels, t = sequence lenght, h = height,
+ // w = width, d = depth
+ switch (input_dims) {
+ case 1:
+ layout_type = memory::format::x;
+ break;
+ case 2:
+ layout_type = memory::format::nc;
+ break;
+ case 3:
+ layout_type = memory::format::tnc;
+ break;
+ case 4:
+ layout_type = memory::format::nchw;
+ break;
+ case 5:
+ layout_type = memory::format::ncdhw;
+ break;
+ default:
+ OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5 and >=1"));
+ }
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
MklDnnData<T> src(&cpu_engine);
@@ -75,7 +101,7 @@ class MklSoftmaxOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
+ : memory::desc(src_dims, MklDnnType<T>(), layout_type);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -84,10 +110,11 @@ class MklSoftmaxOp : public OpKernel {
// data format is "nc" for src and dst; since the src and dst buffer is
// always in 2D shape
src.SetUsrMem(src_md, &src_tensor);
- src.SetOpMemDesc(src_dims, memory::format::nc);
+ src.SetOpMemDesc(src_dims, layout_type);
// creating a memory descriptor
- int axis = 1; // axis to which softmax will be applied
+ // passing outermost dim as default axis, where the softmax is applied
+ int axis = input_dims - 1;
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
auto softmax_fwd_pd =
@@ -107,7 +134,7 @@ class MklSoftmaxOp : public OpKernel {
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
- memory::format::nc);
+ layout_type);
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
} else { // then output is also TF shape
output_mkl_shape.SetMklTensor(false);
diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h
new file mode 100644
index 0000000000..f91244454e
--- /dev/null
+++ b/tensorflow/core/kernels/poisson-loss.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+
+#include <cmath>
+
+#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+class PoissonLossUpdater : public DualLossUpdater {
+ public:
+ // Update is found by a Newton algorithm (see readme.md).
+ double ComputeUpdatedDual(const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) const final {
+ // Newton algorithm converges quadratically so 10 steps will be largely
+ // enough to achieve a very good precision
+ static const int newton_total_steps = 10;
+ // Initialize the Newton optimization at x such that
+ // exp(x) = label - current_dual
+ const double y_minus_a = label - current_dual;
+ double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
+ for (int i = 0; i < newton_total_steps; ++i) {
+ x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
+ weighted_example_norm, current_dual);
+ }
+ return label - exp(x);
+ }
+
+ // Dual of poisson loss function.
+ // https://en.wikipedia.org/wiki/Convex_conjugate
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
+ // Dual of the poisson loss function is
+ // (y-a)*(log(y-a)-1), where a is the dual variable.
+ // It is defined only for a<y.
+ const double y_minus_a = example_label - current_dual;
+ if (y_minus_a == 0.0) {
+ // (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
+ return 0.0;
+ }
+ if (y_minus_a < 0.0) {
+ return std::numeric_limits<double>::max();
+ }
+ return y_minus_a * (log(y_minus_a) - 1) * example_weight;
+ }
+
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
+ return (exp(wx) - wx * example_label) * example_weight;
+ }
+
+ double PrimalLossDerivative(const double wx, const double label,
+ const double example_weight) const final {
+ return (exp(wx) - label) * example_weight;
+ }
+
+ // TODO(chapelle): We need to introduce a maximum_prediction parameter,
+ // expose that parameter to the user and have this method return
+ // 1.0/maximum_prediction.
+ // Setting this at 1 for now, it only impacts the adaptive sampling.
+ double SmoothnessConstant() const final { return 1; }
+
+ Status ConvertLabel(float* const example_label) const final {
+ if (*example_label < 0.0) {
+ return errors::InvalidArgument(
+ "Only non-negative labels can be used with the Poisson log loss. "
+ "Found example with label: ", *example_label);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // One Newton step (see readme.md).
+ double NewtonStep(const double x, const int num_loss_partitions,
+ const double label, const double wx,
+ const double example_weight,
+ const double weighted_example_norm,
+ const double current_dual) const {
+ const double expx = exp(x);
+ const double numerator =
+ x - wx - num_loss_partitions * weighted_example_norm *
+ example_weight * (label - current_dual - expx);
+ const double denominator =
+ 1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
+ return x - numerator / denominator;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index c5b73139bb..8a3e3dc0a9 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<complex128>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<complex128>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_complex64.cc b/tensorflow/core/kernels/qr_op_complex64.cc
index 4e14f2639c..467fa6c2d6 100644
--- a/tensorflow/core/kernels/qr_op_complex64.cc
+++ b/tensorflow/core/kernels/qr_op_complex64.cc
@@ -20,7 +20,11 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 51885eb355..05537a0eaa 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<double>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index d0a1dd4204..6aebd98186 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<float>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
index 9020121169..3d49af7cb1 100644
--- a/tensorflow/core/kernels/range_sampler_test.cc
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -45,7 +45,7 @@ class RangeSamplerTest : public ::testing::Test {
// Using a fixed random seed to make the test deterministic.
random::PhiloxRandom philox(123, 17);
random::SimplePhilox rnd(&philox);
- sampler_->SampleBatch(&rnd, false, &a);
+ sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
for (int i = 0; i < num_samples; i++) {
int64 val = a[i];
ASSERT_GE(val, 0);
@@ -251,8 +251,9 @@ TEST_F(RangeSamplerTest, All) {
extras[0] = 0;
extras[1] = batch_size - 1;
sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
- false, &batch, &batch_expected, extras,
- &extras_expected);
+ false, absl::MakeSpan(batch),
+ absl::MakeSpan(batch_expected), extras,
+ absl::MakeSpan(extras_expected));
for (int i = 0; i < batch_size; i++) {
EXPECT_EQ(i, batch[i]);
EXPECT_EQ(1, batch_expected[i]);
@@ -281,17 +282,18 @@ TEST_F(RangeSamplerTest, Unique) {
std::vector<float> expected(range);
// Sample one batch and get the expected counts of all values
- sampler_->SampleBatchGetExpectedCount(
- &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
+ MutableArraySlice<float>(), all_values,
+ absl::MakeSpan(expected));
// Check that all elements are unique
std::set<int64> s(batch.begin(), batch.end());
CHECK_EQ(batch_size, s.size());
for (int trial = 0; trial < num_batches; trial++) {
std::vector<float> trial_expected(range);
- sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
- MutableArraySlice<float>(),
- all_values, &trial_expected);
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ all_values, absl::MakeSpan(trial_expected));
for (int i = 0; i < range; i++) {
EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
}
@@ -318,8 +320,8 @@ TEST_F(RangeSamplerTest, Avoid) {
// We expect to pick all elements of [0, 100) except the avoided two.
sampler_->SampleBatchGetExpectedCountAvoid(
- &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
- MutableArraySlice<float>(), avoided);
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ ArraySlice<int64>(), MutableArraySlice<float>(), avoided);
int sum = 0;
for (auto val : batch) {
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 1c071d3d41..a8e9b3261c 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -251,7 +251,7 @@ Status Examples::SampleAdaptiveProbabilities(
num_weight_vectors);
const double kappa = example_state_data(example_id, 0) +
loss_updater->PrimalLossDerivative(
- example_statistics.wx[0], label, example_weight);
+ example_statistics.wx[0], label, 1.0);
probabilities_[example_id] = example_weight *
sqrt(examples_[example_id].squared_norm_ +
regularization.symmetric_l2() *
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 05c835ebc4..3bd4168dc7 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
@@ -75,6 +76,8 @@ struct ComputeOptions {
loss_updater.reset(new HingeLossUpdater);
} else if (loss_type == "smooth_hinge_loss") {
loss_updater.reset(new SmoothHingeLossUpdater);
+ } else if (loss_type == "poisson_loss") {
+ loss_updater.reset(new PoissonLossUpdater);
} else {
OP_REQUIRES(
context, false,
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index f893d4e945..0428909145 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -269,7 +269,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
// Group by all but last dimension, create a set of group values, and add set
// size to output.
- VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
+ VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
std::set<T> group_set;
for (const auto& group : set_st.group(group_ix)) {
PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);
@@ -500,8 +500,8 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
std::vector<int64> group_indices;
int64 num_elements;
@@ -621,11 +621,11 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set1_grouper = set1_st.group(
- VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
+ auto set1_grouper =
+ set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
auto set1_group_it = set1_grouper.begin();
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
// Group by rows, and iterate over rows of both sets in parallel, creating a
diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc
index dc3119bba4..37664fe8df 100644
--- a/tensorflow/core/kernels/sparse_softmax_op.cc
+++ b/tensorflow/core/kernels/sparse_softmax_op.cc
@@ -90,7 +90,7 @@ class SparseSoftmaxOp : public OpKernel {
// { 0, ..., rank-1 }.
const ArraySlice<int64> kReorderDims(dims);
// All but the last dim -- the class dimension to be max-reduced along.
- const ArraySlice<int64> kGroupByDims(kReorderDims, 0, rank - 1);
+ const ArraySlice<int64> kGroupByDims = kReorderDims.subspan(0, rank - 1);
st.Reorder<T>(kReorderDims);
int count = 0;
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 271329599f..9a07ded17d 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
-
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include <algorithm>
@@ -201,7 +200,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
typename TTypes<T>::ConstScalar l2_shrinkage,
typename TTypes<T>::ConstScalar lr_power) {
auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
- auto new_accum = accum + grad_with_shrinkage.square();
+ auto new_accum = accum + grad * grad;
// special case for which lr_power=-0.5.
if (lr_power() == static_cast<T>(-0.5)) {
linear.device(d) +=
@@ -226,7 +225,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
var.device(d) = (linear.abs() > linear.constant(l1()))
.select(pre_shrink, var.constant(static_cast<T>(0)));
}
- accum.device(d) += grad_with_shrinkage.square();
+ accum.device(d) += grad * grad;
}
};
@@ -2167,15 +2166,15 @@ class SparseApplyFtrlOp : public OpKernel {
// Use a macro to implement the computation here due to the templating of the
// eigen tensor library.
-#define COMPUTE_FTRL(grad_to_use) \
- auto new_accum = accum + grad_to_use.square(); \
+#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \
+ auto new_accum = accum + grad.square(); \
if (lr_power_scalar == static_cast<T>(-0.5)) { \
- linear += \
- grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - \
+ (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
} else { \
- linear += grad_to_use - (new_accum.pow(-lr_power_scalar) - \
- accum.pow(-lr_power_scalar)) / \
- lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
+ accum.pow(-lr_power_scalar)) / \
+ lr_scalar * var; \
} \
auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar); \
auto x = l1_reg_adjust - linear; \
@@ -2188,14 +2187,14 @@ class SparseApplyFtrlOp : public OpKernel {
linear.constant(static_cast<T>(2) * l2_scalar); \
var = x / y; \
} \
- accum += grad_to_use.square();
+ accum += grad.square();
if (has_l2_shrinkage) {
auto grad_with_shrinkage =
grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
- COMPUTE_FTRL(grad_with_shrinkage);
+ COMPUTE_FTRL(grad, grad_with_shrinkage);
} else {
- COMPUTE_FTRL(grad);
+ COMPUTE_FTRL(grad, grad);
}
}
#undef COMPUTE_FTRL
@@ -2228,12 +2227,12 @@ class SparseApplyFtrlOp : public OpKernel {
T g;
if (has_l2_shrinkage) {
g = grad_flat(i) +
- (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(i));
+ (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
} else {
g = grad_flat(i);
}
- T updated_a = a + g * g;
+ T updated_a = a + grad_flat(i) * grad_flat(i);
using Eigen::numext::pow;
T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
sigma /= lr_scalar;
@@ -2856,9 +2855,8 @@ class ApplyAdaMaxOp : public OpKernel {
const Device& device = ctx->template eigen_device<Device>();
functor::ApplyAdaMax<Device, T>()(
device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
- beta1_power.scalar<T>(), lr.scalar<T>(),
- beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
- grad.flat<T>());
+ beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(),
+ beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
@@ -2867,16 +2865,16 @@ class ApplyAdaMaxOp : public OpKernel {
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>); \
REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
- .HostMemory("var") \
- .HostMemory("m") \
- .HostMemory("v") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T"), \
+ .HostMemory("var") \
+ .HostMemory("m") \
+ .HostMemory("v") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2889,7 +2887,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
- void ApplyAdaMax<GPUDevice, T>::operator()( \
+ void ApplyAdaMax<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
typename TTypes<T>::ConstScalar beta1_power, \
@@ -2897,7 +2895,7 @@ namespace functor {
typename TTypes<T>::ConstScalar beta1, \
typename TTypes<T>::ConstScalar beta2, \
typename TTypes<T>::ConstScalar epsilon, \
- typename TTypes<T>::ConstFlat grad); \
+ typename TTypes<T>::ConstFlat grad); \
extern template struct ApplyAdaMax<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index d6f3f26cd5..5c917e80c1 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -61,9 +61,7 @@ struct bfloat16 {
}
B16_DEVICE_FUNC explicit bfloat16(const float v) {
- // TODO(asabne) : change the below line to
- // value = round_to_bfloat16(v).value;
- value = truncate_to_bfloat16(v).value;
+ value = round_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 49a8a4dbd4..982901a39c 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -131,11 +131,25 @@ inline string FormatNodeNameForError(const string& name) {
// LINT.ThenChange(//tensorflow/python/client/session.py)
template <typename T>
string FormatNodeNamesForError(const T& names) {
- ::tensorflow::str_util::Formatter<string> f(
- [](string* output, const string& s) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
});
- return ::tensorflow::str_util::Join(names, ", ", f);
+}
+// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and
+// `^^key:value^^` in a follow-on CL.
+// LINT.IfChange
+inline string FormatColocationNodeForError(const string& name) {
+ return strings::StrCat("^^colocation_node:", name, "^^");
+}
+// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
+template <typename T>
+string FormatColocationNodeForError(const T& names) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
+ ::tensorflow::strings::StrAppend(output,
+ FormatColocationNodeForError(s));
+ });
}
// The CanonicalCode() for non-errors.
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
deleted file mode 100644
index 4c488066e4..0000000000
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-#include <algorithm>
-#include <iostream>
-
-namespace tensorflow {
-
-std::ostream& operator<<(std::ostream& o, StringPiece piece) {
- o.write(piece.data(), piece.size());
- return o;
-}
-
-size_t StringPiece::find(char c, size_t pos) const {
- if (pos >= size_) {
- return npos;
- }
- const char* result =
- reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
- return result != nullptr ? result - data_ : npos;
-}
-
-// Search range is [0..pos] inclusive. If pos == npos, search everything.
-size_t StringPiece::rfind(char c, size_t pos) const {
- if (size_ == 0) return npos;
- for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
- if (*p == c) {
- return p - data_;
- }
- }
- return npos;
-}
-
-StringPiece StringPiece::substr(size_t pos, size_t n) const {
- if (pos > size_) pos = size_;
- if (n > size_ - pos) n = size_ - pos;
- return StringPiece(data_ + pos, n);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 02dded42c1..e7b17c9b36 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -31,124 +31,13 @@ limitations under the License.
#include <string.h>
#include <iosfwd>
#include <string>
-#include <type_traits>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class StringPiece {
- public:
- typedef size_t size_type;
-
- // Create an empty slice.
- StringPiece() : data_(nullptr), size_(0) {}
-
- // Create a slice that refers to d[0,n-1].
- StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
-
- // Create a slice that refers to the contents of "s"
- StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
-
- // Create a slice that refers to s[0,strlen(s)-1]
- StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
-
- // Return a pointer to the beginning of the referenced data
- const char* data() const { return data_; }
-
- // Return the length (in bytes) of the referenced data
- size_t size() const { return size_; }
-
- // Return true iff the length of the referenced data is zero
- bool empty() const { return size_ == 0; }
-
- typedef const char* const_iterator;
- typedef const char* iterator;
- iterator begin() const { return data_; }
- iterator end() const { return data_ + size_; }
-
- static const size_t npos = size_type(-1);
-
- // Return the ith byte in the referenced data.
- // REQUIRES: n < size()
- char operator[](size_t n) const {
- assert(n < size());
- return data_[n];
- }
-
- // Drop the first "n" bytes from this slice.
- void remove_prefix(size_t n) {
- assert(n <= size());
- data_ += n;
- size_ -= n;
- }
-
- void remove_suffix(size_t n) {
- assert(size_ >= n);
- size_ -= n;
- }
-
- size_t find(char c, size_t pos = 0) const;
- size_t rfind(char c, size_t pos = npos) const;
-
- StringPiece substr(size_t pos, size_t n = npos) const;
-
- // Three-way comparison. Returns value:
- // < 0 iff "*this" < "b",
- // == 0 iff "*this" == "b",
- // > 0 iff "*this" > "b"
- int compare(StringPiece b) const;
-
- // Converts to various kinds of strings, including `std::basic_string`.
- template <typename S>
- explicit operator S() const {
- static_assert(
- std::is_same<char, typename S::value_type>::value,
- "Type mismatch: S must be a string with character type char.");
- static_assert(
- std::is_same<std::char_traits<char>, typename S::traits_type>::value,
- "Type mismatch: S must be a string with traits type "
- "std::char_traits<char>.");
- if (!data()) return {};
- return S(data(), size());
- }
-
- private:
- const char* data_;
- size_t size_;
-
- // Intentionally copyable
-};
-
-inline bool operator==(StringPiece x, StringPiece y) {
- return ((x.size() == y.size()) &&
- (memcmp(x.data(), y.data(), x.size()) == 0));
-}
-
-inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
-
-inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
-inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
-inline bool operator<=(StringPiece x, StringPiece y) {
- return x.compare(y) <= 0;
-}
-inline bool operator>=(StringPiece x, StringPiece y) {
- return x.compare(y) >= 0;
-}
-
-inline int StringPiece::compare(StringPiece b) const {
- const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
- int r = memcmp(data_, b.data_, min_len);
- if (r == 0) {
- if (size_ < b.size_)
- r = -1;
- else if (size_ > b.size_)
- r = +1;
- }
- return r;
-}
-
-// allow StringPiece to be logged
-extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+// Deprecated: please use absl::string_view directly.
+using StringPiece = absl::string_view;
} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
index b773a65569..8f47faf89e 100644
--- a/tensorflow/core/lib/gtl/array_slice.h
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -13,293 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An ArraySlice<T> represents an immutable array of elements of type
-// T. It has a length "length", and a base pointer "ptr", and the
-// array it represents contains the elements "ptr[0] .. ptr[len-1]".
-// The backing store for the array is *not* owned by the ArraySlice
-// object, and clients must arrange for the backing store to remain
-// live while the ArraySlice object is in use.
-//
-// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
-// array elements of type T.
-//
-// Implicit conversion operations are provided from types such as
-// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
-// objects constructed from types in this way may be invalidated by
-// any operations that mutate the underlying vector.
-//
-// One common use for ArraySlice is when passing arguments to a
-// routine where you want to be able to accept a variety of array
-// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
-// etc.). The usual approach here is to have the client explicitly
-// pass in a pointer and a length, as in:
-//
-// void MyRoutine(const int* elems, int N) {
-// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
-// }
-//
-// Unfortunately, this leads to ugly and error-prone code at the call site:
-//
-// std::vector<int> my_vector;
-// MyRoutine(vector_as_array(&my_vector), my_vector.size());
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
-//
-// int my_array[10];
-// MyRoutine(my_array, 10);
-//
-// Instead, you can use an ArraySlice as the argument to the routine:
-//
-// void MyRoutine(ArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
-// }
-//
-// This makes the call sites cleaner, for the most part:
-//
-// std::vector<int> my_vector;
-// MyRoutine(my_vector);
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector);
-//
-// int my_array[10];
-// MyRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
-//
-// MutableArraySlice<T> represents a mutable array of elements, and, like
-// ArraySlice, does not own the backing store. The implicit constructors it
-// provides allow functions not to worry about whether their mutable arguments
-// refer to vectors, arrays, proto2::RepeatedFields, etc.:
-//
-// void MyMutatingRoutine(MutableArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
-// }
-//
-// std::vector<int> my_vector;
-// MyMutatingRoutine(&my_vector);
-//
-// int my_array[10];
-// MyMutatingRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
-//
-// MyProto my_proto;
-// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
-// MyMutatingRoutine(my_proto.mutable_value());
-
#ifndef TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
#define TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
-#include <initializer_list>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "absl/types/span.h"
+// TODO(timshen): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
namespace gtl {
template <typename T>
-class ArraySlice {
- private:
- typedef array_slice_internal::ArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- ArraySlice() : impl_(nullptr, 0) {}
- ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- template <size_t N>
- ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- // The constructor for any class supplying 'data() const' that returns either
- // const T* or a less const-qualified version of it, and 'some_integral_type
- // size() const'. proto2::RepeatedField<T>, string and (since C++11)
- // std::vector<T,A> and std::array<T, N> are examples of this. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- ArraySlice(const V& v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Implicitly constructs an ArraySlice from an initializer list. This makes it
- // possible to pass a brace-enclosed initializer list to a function expecting
- // an ArraySlice:
- // void Process(ArraySlice<int> x);
- // Process({1, 2, 3});
- // The data referenced by the initializer_list must outlive this
- // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
- // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
- // reference data that is no longer valid.
- ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
- : impl_(v.begin(), v.size()) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- ArraySlice(const ArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- const_pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- const_reference operator[](size_type i) const { return impl_[i]; }
- const_reference at(size_type i) const { return impl_.at(i); }
- const_reference front() const { return impl_.front(); }
- const_reference back() const { return impl_.back(); }
-
- const_iterator begin() const { return impl_.begin(); }
- const_iterator end() const { return impl_.end(); }
- const_reverse_iterator rbegin() const { return impl_.rbegin(); }
- const_reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
-
- // These relational operators have the same semantics as the
- // std::vector<T> relational operators: they do deep (element-wise)
- // comparisons. Array slices are equal iff their size is the same
- // and all their elements are equal.
- bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
- bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
-
- private:
- Impl impl_;
-};
-
-// Mutable version of ArraySlice, which allows the clients to mutate the
-// underlying data. It is implicitly convertible to ArraySlice since it provides
-// the data() and size() methods with correct signatures. When a
-// MutableArraySlice is created from a pointer to a container (as opposed to raw
-// memory pointer), the pointer must not be null.
-//
-// A note on const-ness: "mutable" here refers to the mutability of the
-// underlying data, not of the slice itself. It is perfectly reasonable to have
-// a variable of type "const MutableArraySlice<T>"; this means that the bounds
-// of the view on the array cannot be changed, but the underlying data in the
-// array still may be modified. This is akin to a "T* const" pointer, as opposed
-// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
-template <typename T>
-class MutableArraySlice {
- private:
- typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- MutableArraySlice() : impl_(nullptr, 0) {}
- MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- template <size_t N>
- MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- MutableArraySlice(
- InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
- // (the former is called if both exist), and 'some_integral_type size()
- // const'. proto2::RepeatedField is an example of this. Also supports string
- // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- MutableArraySlice(V* v) // NOLINT(runtime/explicit)
- : impl_(v) {}
+using ArraySlice = absl::Span<const T>;
- // Substring of another MutableArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- // Accessors.
- pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- reference operator[](size_type i) const { return impl_[i]; }
- reference at(size_type i) const { return impl_.at(i); }
- reference front() const { return impl_.front(); }
- reference back() const { return impl_.back(); }
-
- iterator begin() const { return impl_.begin(); }
- iterator end() const { return impl_.end(); }
- reverse_iterator rbegin() const { return impl_.rbegin(); }
- reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
-
- bool operator==(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) == other;
- }
- bool operator!=(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) != other;
- }
-
- private:
- Impl impl_;
-};
-
-template <typename T>
-const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
template <typename T>
-const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+using MutableArraySlice = absl::Span<T>;
} // namespace gtl
} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
deleted file mode 100644
index 689dd8a646..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_internal.h
+++ /dev/null
@@ -1,269 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
-// array_slice.h.
-
-// Helper functions and templates for ArraySlice.
-
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-
-#include <stddef.h>
-#include <algorithm>
-#include <iterator>
-#include <memory>
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace array_slice_internal {
-
-// Template logic for generic constructors.
-
-// Wrappers whose Get() delegates to the appropriate method of a container, and
-// is defined when this method exists. Delegates to the const method if C is a
-// const type.
-struct Data {
- template <typename C>
- static decltype(std::declval<C>().data()) Get(C* v) {
- return v->data();
- }
-};
-
-struct MutableData {
- template <typename C>
- static decltype(std::declval<C>().mutable_data()) Get(C* v) {
- return v->mutable_data();
- }
-};
-
-struct Size {
- template <typename C>
- static decltype(std::declval<C>().size()) Get(C* v) {
- return v->size();
- }
-};
-
-struct MutableStringData {
- // Defined only for string.
- static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
-};
-
-// Checks whether M::Get(C*) is defined and has a return type R such that
-// Checker::valid<R>()==true.
-template <typename M, typename Checker, typename C>
-struct HasGetHelper : public M {
- private:
- struct None {};
- // M::Get is selected when it is viable. Get(...) is selected otherwise.
- using M::Get;
- static None Get(...);
-
- public:
- static constexpr bool HasGet() {
- using Result = decltype(Get(std::declval<C*>()));
- return !std::is_same<Result, None>() && Checker::template valid<Result>();
- }
-};
-
-// Defines HasGet() for a particular method, container, and checker. If
-// HasGet()==true, provides Get() that delegates to the method.
-template <typename M, typename Checker, typename C,
- bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
-struct Wrapper {
- static constexpr bool HasGet() { return false; }
-};
-
-template <typename M, typename Checker, typename C>
-struct Wrapper<M, Checker, C, true> {
- static constexpr bool HasGet() { return true; }
- static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
-};
-
-// Type checker for a method returning an integral value.
-struct SizeChecker {
- template <typename R>
- static constexpr bool valid() {
- return std::is_integral<R>::value;
- }
-};
-
-// Type checker for a method returning either a pointer to T or a less const
-// version of that.
-template <typename T>
-struct DataChecker {
- // We want to enable conversion from std::vector<T*> to ArraySlice<const T*>
- // but
- // disable conversion from std::vector<Derived> to ArraySlice<Base>. Here we
- // use
- // the fact that U** is convertible to Q* const* if and only if Q is the same
- // type or a more cv-qualified version of U.
- template <typename R>
- static constexpr bool valid() {
- return std::is_convertible<R*, T* const*>::value;
- }
-};
-
-// Aliases to A if A::HasGet()==true, or to B otherwise.
-template <typename A, typename B>
-using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
-
-// Wraps C::data() const, returning a pointer to const data.
-template <typename T, typename C>
-using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
-
-// Wraps a method returning a pointer to mutable data. Prefers data() over
-// mutable_data(), and handles strings when T==char. If data() returns a pointer
-// to mutable data, it is most likely overloaded, but may also be a single
-// method 'T* C::data() const' in a non-STL-compliant container.
-template <typename T, typename C>
-using ContainerMutableData =
- FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
- FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
- Wrapper<MutableStringData, DataChecker<T>, C>>>;
-
-// Wraps C::size() const.
-template <typename C>
-using ContainerSize = Wrapper<Size, SizeChecker, const C>;
-
-// Implementation class for ArraySlice and MutableArraySlice. In the case of
-// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
-// mutable type.
-template <typename T>
-class ArraySliceImplBase {
- public:
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
- typedef std::reverse_iterator<iterator> reverse_iterator;
- typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
- typedef size_t size_type;
- typedef ptrdiff_t difference_type;
-
- static const size_type npos = static_cast<size_type>(-1);
-
- ArraySliceImplBase(pointer array, size_type length)
- : ptr_(array), length_(length) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
- : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
-
- // Some of the const methods below return pointers and references to mutable
- // data. This is only the case in this internal class; ArraySlice and
- // MutableArraySlice provide deep-constness.
-
- pointer data() const { return ptr_; }
- size_type size() const { return length_; }
-
- void clear() {
- ptr_ = nullptr;
- length_ = 0;
- }
-
- reference operator[](size_type i) const { return ptr_[i]; }
- reference at(size_type i) const {
- DCHECK_LT(i, length_);
- return ptr_[i];
- }
- reference front() const {
- DCHECK_GT(length_, 0);
- return ptr_[0];
- }
- reference back() const {
- DCHECK_GT(length_, 0);
- return ptr_[length_ - 1];
- }
-
- void remove_prefix(size_type n) {
- DCHECK_GE(length_, n);
- ptr_ += n;
- length_ -= n;
- }
- void remove_suffix(size_type n) {
- DCHECK_GE(length_, n);
- length_ -= n;
- }
-
- iterator begin() const { return ptr_; }
- iterator end() const { return ptr_ + length_; }
- reverse_iterator rbegin() const { return reverse_iterator(end()); }
- reverse_iterator rend() const { return reverse_iterator(begin()); }
-
- bool operator==(const ArraySliceImplBase& other) const {
- if (size() != other.size()) return false;
- if (data() == other.data()) return true;
- return std::equal(data(), data() + size(), other.data());
- }
- bool operator!=(const ArraySliceImplBase& other) const {
- return !(*this == other);
- }
-
- private:
- pointer ptr_;
- size_type length_;
-};
-
-template <typename T>
-class ArraySliceImpl : public ArraySliceImplBase<const T> {
- public:
- using ArraySliceImplBase<const T>::ArraySliceImplBase;
-
- // Defined iff the data and size accessors for the container C have been
- // defined.
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- // Constructs from a container when EnableIfConvertibleFrom is
- // defined. std::addressof handles types with overloaded operator&.
- template <typename C>
- explicit ArraySliceImpl(const C& v)
- : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
- ContainerSize<C>::Get(std::addressof(v))) {}
-};
-
-template <typename T>
-class MutableArraySliceImpl : public ArraySliceImplBase<T> {
- public:
- using ArraySliceImplBase<T>::ArraySliceImplBase;
-
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- template <typename C>
- explicit MutableArraySliceImpl(C* v)
- : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
- ContainerSize<C>::Get(v)) {}
-};
-
-} // namespace array_slice_internal
-} // namespace gtl
-} // namespace tensorflow
-
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
deleted file mode 100644
index c798a488cb..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_test.cc
+++ /dev/null
@@ -1,664 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/array_slice.h"
-
-#include <algorithm>
-#include <array>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace {
-
-typedef ArraySlice<int> IntSlice;
-typedef ArraySlice<char> CharSlice;
-typedef MutableArraySlice<int> MutableIntSlice;
-typedef MutableArraySlice<char> MutableCharSlice;
-typedef std::vector<int> IntVec;
-
-// Append 0..len-1 to *v
-template <typename Vector>
-static void Fill(Vector* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
- IntSlice other; // To test the assignment return value.
- IntSlice v = other = vorig;
- const int len = vec.size();
- EXPECT_EQ(v.size(), vec.size());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(v[i], vec[i]);
- EXPECT_EQ(v.at(i), vec[i]);
- }
- EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
-
- int counter = 0;
- for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.remove_suffix(1);
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- if (len > 1) {
- v.remove_prefix(1);
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i + 1, v[i]);
- }
- }
- }
-}
-
-// The element access test that is applicable both when MutableArraySlice is
-// const and when it's not.
-template <class V>
-void MutableTestHelperTemplated(V v, int* ptr, const int len) {
- CHECK_EQ(v.size(), len);
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(ptr + i, &v[i]);
- EXPECT_EQ(ptr + i, &v.at(i));
- }
- EXPECT_EQ(ptr, v.begin());
- EXPECT_EQ(ptr + len, v.end());
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
-
- if (len > 0) {
- EXPECT_EQ(ptr, &v.front());
- EXPECT_EQ(ptr + len - 1, &v.back());
- EXPECT_EQ(ptr + len - 1, &*v.rbegin());
- EXPECT_EQ(ptr, &*(v.rend() - 1));
- }
-}
-
-static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
- const int len) {
- // Test the data accessors both when the MutableArraySlice is declared const,
- // and when it is not.
- MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
- MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
-
- MutableIntSlice other; // To test the assignment return value.
- MutableIntSlice v = other = vorig;
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- // Test that elements are assignable.
- v[0] = 1;
- v.front() = 2;
- v.back() = 5;
- *v.data() = 4;
- std::fill(v.begin(), v.end(), 5);
- std::fill(v.rbegin(), v.rend(), 6);
- // Test size-changing methods.
- v.remove_suffix(1);
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i, &v[i]);
- }
- if (len > 1) {
- v.remove_prefix(1);
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i + 1, &v[i]);
- }
- }
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
- EXPECT_EQ(v.size(), vec.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(v[i], vec[i]);
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
- TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
-}
-
-static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-
-static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-// A struct supplying the data(), mutable_data() and size() methods, just like
-// e.g. proto2::RepeatedField.
-struct RepeatedField {
- std::vector<int> storage;
- const int* data() const { return storage.data(); }
- int* mutable_data() { return storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying the data() (both mutable and const versions) and
-// size(). It also supplies mutable_data() but we test that data() is selected
-// instead.
-struct ContainerWithOverloads {
- std::vector<int> storage;
- std::vector<int> wrong_storage;
- const int* data() const { return storage.data(); }
- int* data() { return storage.data(); }
- // MutableArraySlice should not call mutable_data(), preferring data()
- // instead.
- int* mutable_data() { return wrong_storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying data() and size() methods.
-struct ContainerWithShallowConstData {
- std::vector<int> storage;
- int* data() const { return const_cast<int*>(storage.data()); }
- int size() const { return storage.size(); }
-};
-
-TEST(IntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- TestHelper(IntSlice(vec), vec);
- TestHelper(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, WithPosAndLen) {
- IntVec vec;
- Fill(&vec, 20);
- for (size_t len = 0; len < vec.size(); len++) {
- IntVec subvec(vec.begin(), vec.begin() + len);
- TestImplicitConversion(IntSlice(vec, 0, len), subvec);
- TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
- }
- EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
- EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
- TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
-}
-
-TEST(IntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice v(vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec, bvec;
- Fill(&avec, l1);
- Fill(&bvec, l2, 100);
- IntSlice a(avec), b(bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(IntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice slice;
- slice = vec;
- TestImplicitConversion(vec, vec);
- TestImplicitConversion(slice, vec);
- TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- IntVec vec;
- Fill(&vec, len);
- IntSlice v = inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(inline_vec, vec);
- }
-}
-
-TEST(IntSlice, StaticArrayConversion) {
- int array[20];
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(array));
- std::copy(vec.begin(), vec.end(), array);
- IntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, vec);
-}
-
-TEST(IntSlice, StdArrayConversion) {
- std::array<int, 20> array;
- IntVec vec;
- Fill(&vec, array.size());
- std::copy(vec.begin(), vec.end(), array.begin());
-
- // Check assignment.
- {
- IntSlice v = array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- IntSlice v = {array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(array, vec);
-}
-
-// Values according to the Fill function.
-static const int test_const_array[] = {0, 1, 2};
-
-TEST(IntSlice, ConstStaticArrayConversion) {
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(test_const_array));
- IntSlice v = test_const_array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(test_const_array, vec);
-}
-
-TEST(IntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- IntVec vec;
- Fill(&vec, 20);
- repeated_field.storage = vec;
- IntSlice v = repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(repeated_field, vec);
-}
-
-TEST(IntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, MutableIntSliceConversion) {
- IntVec vec(20);
- IntSlice slice = MutableIntSlice(&vec);
- EXPECT_EQ(vec.size(), slice.size());
- EXPECT_EQ(vec.data(), slice.data());
-}
-
-TEST(IntSlice, Equality) {
- IntVec vec1(20);
- IntVec vec2(20);
- // These two slices are from different vectors, but have the same
- // size and have the same elements (right now). They should
- // compare equal.
- const IntSlice from1(vec1);
- const IntSlice from2(vec2);
- EXPECT_EQ(from1, from1);
- EXPECT_EQ(from1, from2);
-
- // This verifies that MutableArraySlices can be compared freely with
- // ArraySlices.
- const MutableIntSlice mutable_from1(&vec1);
- const MutableIntSlice mutable_from2(&vec2);
- EXPECT_EQ(from1, mutable_from1);
- EXPECT_EQ(mutable_from1, from1);
- EXPECT_EQ(mutable_from1, mutable_from2);
- EXPECT_EQ(mutable_from2, mutable_from1);
-
- // With a different size, the array slices should not be equal.
- EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
-
- // With different contents, the array slices should not be equal.
- ++vec2.back();
- EXPECT_NE(from1, from2);
-}
-
-// Compile-asserts that the argument has the expected type.
-template <typename Expected, typename T>
-void CheckType(const T& value) {
- ::testing::StaticAssertTypeEq<Expected, T>();
-}
-
-TEST(IntSlice, ExposesContainerTypesAndConsts) {
- IntSlice slice;
- const IntSlice const_slice;
- CheckType<IntSlice::iterator>(slice.begin());
- CheckType<IntSlice::const_iterator>(const_slice.end());
- CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
- CheckType<IntSlice::reverse_iterator>(slice.rend());
- ::testing::StaticAssertTypeEq<int, IntSlice::value_type>();
- ::testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
- EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
-}
-
-void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
-
-void TestRange(IntSlice slice, int from, int to) {
- ASSERT_EQ(to - from + 1, slice.size());
- for (size_t i = 0; i < slice.size(); ++i) {
- EXPECT_EQ(from + i, slice[i]);
- }
-}
-
-TEST(IntSlice, InitializerListConversion) {
- TestEmpty({});
- TestRange({1}, 1, 1);
- TestRange({10, 11, 12, 13}, 10, 13);
-}
-
-TEST(CharSlice, StringConversion) {
- IntVec vec;
- Fill(&vec, 20);
- string str(vec.begin(), vec.end());
- CharSlice v = str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(str, vec);
-}
-
-TEST(IntPtrSlice, ConstConversion) {
- int one = 1;
- int two = 2;
- std::vector<int*> vec;
- vec.push_back(&one);
- vec.push_back(&two);
- ArraySlice<const int*> v = vec;
- ASSERT_EQ(2, v.size());
- EXPECT_EQ(&one, v[0]);
- EXPECT_EQ(&two, v[1]);
-}
-
-TEST(MutableIntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
- MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
- }
-}
-
-TEST(MutableIntSlice, WithPosAndLen) {
- IntVec vec(20);
- for (size_t len = 0; len < vec.size(); len++) {
- TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
- TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
- vec.data(), len);
- }
- EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
- EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
- TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
- vec.data(), vec.size());
-}
-
-TEST(MutableIntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice v(&vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(MutableIntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec(l1), bvec(l2);
- MutableIntSlice a(&avec), b(&bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(&avec[i], &b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(&bvec[i], &a[i]);
- }
- }
- }
-}
-
-TEST(MutableIntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice slice;
- slice = &vec;
- TestImplicitConversion(&vec, vec.data(), len);
- TestImplicitConversion(slice, vec.data(), len);
- TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
- len);
- }
-}
-
-TEST(MutableIntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- MutableIntSlice v = &inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&inline_vec, inline_vec.data(), inline_vec.size());
- }
-}
-
-TEST(MutableIntSlice, StaticArrayConversion) {
- int array[20];
- MutableIntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
-}
-
-TEST(MutableIntSlice, StdArrayConversion) {
- std::array<int, 20> array;
-
- // Check assignment.
- {
- MutableIntSlice v = &array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- MutableIntSlice v = {&array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(&array, &array[0], array.size());
-}
-
-TEST(MutableIntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- Fill(&repeated_field.storage, 20);
- MutableIntSlice v = &repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
- repeated_field.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, TypedefsAndConstants) {
- ::testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
- ::testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
- ::testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
- ::testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
-
- EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences_Const) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- const MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-bool TestMutableOverload(MutableIntSlice slice) { return false; }
-
-bool TestMutableOverload(MutableCharSlice slice) { return true; }
-
-TEST(MutableCharSlice, StringConversion) {
- for (int len = 0; len < 20; len++) {
- string str(len, '\0');
- MutableCharSlice v = &str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(v, str.data(), str.size());
- }
- // Verify that only the correct overload is feasible. Note that this would
- // fail if the string ctor was declared simply as MutableArraySlice(string*),
- // since in that case both overloads would be feasible.
- string str;
- EXPECT_TRUE(TestMutableOverload(&str));
-
- // Avoid warning "unused function 'TestMutableOverload'"
- int a[1];
- EXPECT_FALSE(TestMutableOverload(a));
-}
-
-} // namespace
-} // namespace gtl
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/optional.cc b/tensorflow/core/lib/gtl/optional.cc
deleted file mode 100644
index 8dea073788..0000000000
--- a/tensorflow/core/lib/gtl/optional.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/optional.h"
-
-namespace tensorflow {
-namespace gtl {
-
-nullopt_t::init_t nullopt_t::init;
-extern const nullopt_t nullopt{nullopt_t::init};
-
-} // namespace gtl
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h
index 7ad916ad3d..238aa18e1e 100644
--- a/tensorflow/core/lib/gtl/optional.h
+++ b/tensorflow/core/lib/gtl/optional.h
@@ -16,861 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
#define TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
-#include <assert.h>
-#include <functional>
-#include <initializer_list>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/platform/logging.h"
+#include "absl/types/optional.h"
namespace tensorflow {
namespace gtl {
-// A value of type gtl::optional<T> holds either a value of T or an
-// "empty" value. When it holds a value of T, it stores it as a direct
-// subobject, so sizeof(optional<T>) is approximately sizeof(T)+1. The interface
-// is based on the upcoming std::optional<T>, and gtl::optional<T> is
-// designed to be cheaply drop-in replaceable by std::optional<T>, once it is
-// rolled out.
-//
-// This implementation is based on the specification in the latest draft as of
-// 2017-01-05, section 20.6.
-//
-// Differences between gtl::optional<T> and std::optional<T> include:
-// - constexpr not used for nonconst member functions.
-// (dependency on some differences between C++11 and C++14.)
-// - nullopt and in_place are not constexpr. We need the inline variable
-// support in C++17 for external linkage.
-// - CHECK instead of throwing std::bad_optional_access.
-// - optional::swap() and swap() relies on std::is_(nothrow_)swappable
-// which is introduced in C++17. So we assume is_swappable is always true
-// and is_nothrow_swappable is same as std::is_trivial.
-// - make_optional cannot be constexpr due to absence of guaranteed copy
-// elision.
-//
-// Synopsis:
-//
-// #include "tensorflow/core/lib/gtl/optional.h"
-//
-// tensorflow::gtl::optional<string> f() {
-// string result;
-// if (...) {
-// ...
-// result = ...;
-// return result;
-// } else {
-// ...
-// return tensorflow::gtl::nullopt;
-// }
-// }
-//
-// int main() {
-// tensorflow::gtl::optional<string> optstr = f();
-// if (optstr) {
-// // non-empty
-// print(optstr.value());
-// } else {
-// // empty
-// error();
-// }
-// }
-template <typename T>
-class optional;
-
-// The tag constant `in_place` is used as the first parameter of an optional<T>
-// constructor to indicate that the remaining arguments should be forwarded
-// to the underlying T constructor.
-struct in_place_t {};
-extern const in_place_t in_place;
-
-// The tag constant `nullopt` is used to indicate an empty optional<T> in
-// certain functions, such as construction or assignment.
-struct nullopt_t {
- struct init_t {};
- static init_t init;
- // It must not be default-constructible to avoid ambiguity for opt = {}.
- // Note the non-const reference, it is to eliminate ambiguity for code like:
- // struct S { int value; };
- //
- // void Test() {
- // optional<S> opt;
- // opt = {{}};
- // }
- explicit constexpr nullopt_t(init_t& /*unused*/) {} // NOLINT
-};
-extern const nullopt_t nullopt;
-
-namespace internal_optional {
-
-// define forward locally because std::forward is not constexpr until C++14
-template <typename T>
-constexpr T&& forward(typename std::remove_reference<T>::type&
- t) noexcept { // NOLINT(runtime/references)
- return static_cast<T&&>(t);
-}
-
-struct empty_struct {};
-// This class stores the data in optional<T>.
-// It is specialized based on whether T is trivially destructible.
-// This is the specialization for non trivially destructible type.
-template <typename T, bool = std::is_trivially_destructible<T>::value>
-class optional_data_dtor_base {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
-
- void destruct() noexcept {
- if (engaged_) {
- data_.~T();
- engaged_ = false;
- }
- }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() { destruct(); }
-};
-
-// Specialization for trivially destructible type.
-template <typename T>
-class optional_data_dtor_base<T, true> {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
- void destruct() noexcept { engaged_ = false; }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() = default;
-};
-
-template <typename T>
-class optional_data : public optional_data_dtor_base<T> {
- protected:
- using base = optional_data_dtor_base<T>;
- using base::base;
-
- T* pointer() { return &this->data_; }
-
- constexpr const T* pointer() const { return &this->data_; }
-
- template <typename... Args>
- void construct(Args&&... args) {
- new (pointer()) T(std::forward<Args>(args)...);
- this->engaged_ = true;
- }
-
- template <typename U>
- void assign(U&& u) {
- if (this->engaged_) {
- this->data_ = std::forward<U>(u);
- } else {
- construct(std::forward<U>(u));
- }
- }
-
- optional_data() = default;
-
- optional_data(const optional_data& rhs) {
- if (rhs.engaged_) {
- construct(rhs.data_);
- }
- }
-
- optional_data(optional_data&& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- construct(std::move(rhs.data_));
- }
- }
-
- optional_data& operator=(const optional_data& rhs) {
- if (rhs.engaged_) {
- assign(rhs.data_);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- optional_data& operator=(optional_data&& rhs) noexcept(
- std::is_nothrow_move_assignable<T>::value&&
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- assign(std::move(rhs.data_));
- } else {
- this->destruct();
- }
- return *this;
- }
-};
-
-// ordered by level of restriction, from low to high.
-// copyable implies movable.
-enum class copy_traits { copyable = 0, movable = 1, non_movable = 2 };
-
-// base class for enabling/disabling copy/move constructor.
-template <copy_traits>
-class optional_ctor_base;
-
-template <>
-class optional_ctor_base<copy_traits::copyable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = default;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::non_movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = delete;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-// base class for enabling/disabling copy/move assignment.
-template <copy_traits>
-class optional_assign_base;
-
-template <>
-class optional_assign_base<copy_traits::copyable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = default;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::non_movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = delete;
-};
-
+// Deprecated: please use absl::optional directly.
+using absl::make_optional;
+using absl::nullopt;
template <typename T>
-constexpr copy_traits get_ctor_copy_traits() {
- return std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_constructible<T>::value ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-template <typename T>
-constexpr copy_traits get_assign_copy_traits() {
- return std::is_copy_assignable<T>::value &&
- std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_assignable<T>::value &&
- std::is_move_constructible<T>::value
- ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-// Whether T is constructible or convertible from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_from_optional
- : std::integral_constant<
- bool, std::is_constructible<T, optional<U>&>::value ||
- std::is_constructible<T, optional<U>&&>::value ||
- std::is_constructible<T, const optional<U>&>::value ||
- std::is_constructible<T, const optional<U>&&>::value ||
- std::is_convertible<optional<U>&, T>::value ||
- std::is_convertible<optional<U>&&, T>::value ||
- std::is_convertible<const optional<U>&, T>::value ||
- std::is_convertible<const optional<U>&&, T>::value> {};
-
-// Whether T is constructible or convertible or assignable from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_assignable_from_optional
- : std::integral_constant<
- bool, is_constructible_convertible_from_optional<T, U>::value ||
- std::is_assignable<T&, optional<U>&>::value ||
- std::is_assignable<T&, optional<U>&&>::value ||
- std::is_assignable<T&, const optional<U>&>::value ||
- std::is_assignable<T&, const optional<U>&&>::value> {};
-
-} // namespace internal_optional
-
-template <typename T>
-class optional : private internal_optional::optional_data<T>,
- private internal_optional::optional_ctor_base<
- internal_optional::get_ctor_copy_traits<T>()>,
- private internal_optional::optional_assign_base<
- internal_optional::get_assign_copy_traits<T>()> {
- using data_base = internal_optional::optional_data<T>;
-
- public:
- typedef T value_type;
-
- // [optional.ctor], constructors
-
- // A default constructed optional holds the empty value, NOT a default
- // constructed T.
- constexpr optional() noexcept {}
-
- // An optional initialized with `nullopt` holds the empty value.
- constexpr optional(nullopt_t) noexcept {} // NOLINT(runtime/explicit)
-
- // Copy constructor, standard semantics.
- optional(const optional& src) = default;
-
- // Move constructor, standard semantics.
- optional(optional&& src) = default;
-
- // optional<T>(in_place, arg1, arg2, arg3) constructs a non-empty optional
- // with an in-place constructed value of T(arg1,arg2,arg3).
- // TODO(b/34201852): Add std::is_constructible<T, Args&&...> SFINAE.
- template <typename... Args>
- constexpr explicit optional(in_place_t, Args&&... args)
- : data_base(in_place_t(), internal_optional::forward<Args>(args)...) {}
-
- // optional<T>(in_place, {arg1, arg2, arg3}) constructs a non-empty optional
- // with an in-place list-initialized value of T({arg1, arg2, arg3}).
- template <typename U, typename... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- constexpr explicit optional(in_place_t, std::initializer_list<U> il,
- Args&&... args)
- : data_base(in_place_t(), il, internal_optional::forward<Args>(args)...) {
- }
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- constexpr optional(U&& v) // NOLINT
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit constexpr optional(U&& v)
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- // Converting copy constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<const U&, T>::value,
- bool>::type = false>
- optional(const optional<U>& rhs) { // NOLINT
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting copy constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<const U&, T>::value,
- bool>::type = false>
- explicit optional(const optional<U>& rhs) {
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting move constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- optional(optional<U>&& rhs) { // NOLINT
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // Converting move constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit optional(optional<U>&& rhs) {
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // [optional.dtor], destructor, trivial if T is trivially destructible.
- ~optional() = default;
-
- // [optional.assign], assignment
-
- // Assignment from nullopt: opt = nullopt
- optional& operator=(nullopt_t) noexcept {
- this->destruct();
- return *this;
- }
-
- // Copy assignment, standard semantics.
- optional& operator=(const optional& src) = default;
-
- // Move assignment, standard semantics.
- optional& operator=(optional&& src) = default;
-
- // Value assignment
- template <
- typename U = T,
- typename = typename std::enable_if<
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- (!std::is_scalar<T>::value ||
- !std::is_same<T, typename std::decay<U>::type>::value) &&
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value>::type>
- optional& operator=(U&& v) {
- this->assign(std::forward<U>(v));
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- std::is_assignable<T&, const U&>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(const optional<U>& rhs) {
- if (rhs) {
- this->assign(*rhs);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(optional<U>&& rhs) {
- if (rhs) {
- this->assign(std::move(*rhs));
- } else {
- this->destruct();
- }
- return *this;
- }
-
- // [optional.mod], modifiers
- // Destroys the inner T value if one is present.
- void reset() noexcept { this->destruct(); }
-
- // Emplace reconstruction. (Re)constructs the underlying T in-place with the
- // given arguments forwarded:
- //
- // optional<Foo> opt;
- // opt.emplace(arg1,arg2,arg3); (Constructs Foo(arg1,arg2,arg3))
- //
- // If the optional is non-empty, and the `args` refer to subobjects of the
- // current object, then behavior is undefined. This is because the current
- // object will be destructed before the new object is constructed with `args`.
- //
- template <typename... Args,
- typename = typename std::enable_if<
- std::is_constructible<T, Args&&...>::value>::type>
- void emplace(Args&&... args) {
- this->destruct();
- this->construct(std::forward<Args>(args)...);
- }
-
- // Emplace reconstruction with initializer-list. See immediately above.
- template <class U, class... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- void emplace(std::initializer_list<U> il, Args&&... args) {
- this->destruct();
- this->construct(il, std::forward<Args>(args)...);
- }
-
- // [optional.swap], swap
- // Swap, standard semantics.
- void swap(optional& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value&&
- std::is_trivial<T>::value) {
- if (*this) {
- if (rhs) {
- using std::swap;
- swap(**this, *rhs);
- } else {
- rhs.construct(std::move(**this));
- this->destruct();
- }
- } else {
- if (rhs) {
- this->construct(std::move(*rhs));
- rhs.destruct();
- } else {
- // no effect (swap(disengaged, disengaged))
- }
- }
- }
-
- // [optional.observe], observers
- // You may use `*opt`, and `opt->m`, to access the underlying T value and T's
- // member `m`, respectively. If the optional is empty, behavior is
- // undefined.
- constexpr const T* operator->() const { return this->pointer(); }
- T* operator->() {
- assert(this->engaged_);
- return this->pointer();
- }
- constexpr const T& operator*() const& { return reference(); }
- T& operator*() & {
- assert(this->engaged_);
- return reference();
- }
- constexpr const T&& operator*() const&& { return std::move(reference()); }
- T&& operator*() && {
- assert(this->engaged_);
- return std::move(reference());
- }
-
- // In a bool context an optional<T> will return false if and only if it is
- // empty.
- //
- // if (opt) {
- // // do something with opt.value();
- // } else {
- // // opt is empty
- // }
- //
- constexpr explicit operator bool() const noexcept { return this->engaged_; }
-
- // Returns false if and only if *this is empty.
- constexpr bool has_value() const noexcept { return this->engaged_; }
-
- // Use `opt.value()` to get a reference to underlying value. The constness
- // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
- // subobject.
- const T& value() const& {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T& value() & {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T&& value() && { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
- const T&& value() const&& { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
-
- // Use `opt.value_or(val)` to get either the value of T or the given default
- // `val` in the empty case.
- template <class U>
- constexpr T value_or(U&& v) const& {
- return static_cast<bool>(*this) ? **this
- : static_cast<T>(std::forward<U>(v));
- }
- template <class U>
- T value_or(U&& v) && { // NOLINT(build/c++11)
- return static_cast<bool>(*this) ? std::move(**this)
- : static_cast<T>(std::forward<U>(v));
- }
-
- private:
- // Private accessors for internal storage viewed as reference to T.
- constexpr const T& reference() const { return *this->pointer(); }
- T& reference() { return *(this->pointer()); }
-
- // T constraint checks. You can't have an optional of nullopt_t, in_place_t
- // or a reference.
- static_assert(
- !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
- "optional<nullopt_t> is not allowed.");
- static_assert(
- !std::is_same<in_place_t, typename std::remove_cv<T>::type>::value,
- "optional<in_place_t> is not allowed.");
- static_assert(!std::is_reference<T>::value,
- "optional<reference> is not allowed.");
-};
-
-// [optional.specalg]
-// Swap, standard semantics.
-// This function shall not participate in overload resolution unless
-// is_move_constructible_v<T> is true and is_swappable_v<T> is true.
-// NOTE: we assume is_swappable is always true. There will be a compiling error
-// if T is actually not Swappable.
-template <typename T,
- typename std::enable_if<std::is_move_constructible<T>::value,
- bool>::type = false>
-void swap(optional<T>& a, optional<T>& b) noexcept(noexcept(a.swap(b))) {
- a.swap(b);
-}
-
-// NOTE: make_optional cannot be constexpr in C++11 because the copy/move
-// constructor is not constexpr and we don't have guaranteed copy elision
-// util C++17. But they are still declared constexpr for consistency with
-// the standard.
-
-// make_optional(v) creates a non-empty optional<T> where the type T is deduced
-// from v. Can also be explicitly instantiated as make_optional<T>(v).
-template <typename T>
-constexpr optional<typename std::decay<T>::type> make_optional(T&& v) {
- return optional<typename std::decay<T>::type>(std::forward<T>(v));
-}
-
-template <typename T, typename... Args>
-constexpr optional<T> make_optional(Args&&... args) {
- return optional<T>(in_place_t(), internal_optional::forward<Args>(args)...);
-}
-
-template <typename T, typename U, typename... Args>
-constexpr optional<T> make_optional(std::initializer_list<U> il,
- Args&&... args) {
- return optional<T>(in_place_t(), il,
- internal_optional::forward<Args>(args)...);
-}
-
-// Relational operators. Empty optionals are considered equal to each
-// other and less than non-empty optionals. Supports relations between
-// optional<T> and optional<T>, between optional<T> and T, and between
-// optional<T> and nullopt.
-// Note: We're careful to support T having non-bool relationals.
-
-// Relational operators [optional.relops]
-// The C++17 (N4606) "Returns:" statements are translated into code
-// in an obvious way here, and the original text retained as function docs.
-// Returns: If bool(x) != bool(y), false; otherwise if bool(x) == false, true;
-// otherwise *x == *y.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? false
- : static_cast<bool>(x) == false ? true : *x == *y;
-}
-// Returns: If bool(x) != bool(y), true; otherwise, if bool(x) == false, false;
-// otherwise *x != *y.
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? true
- : static_cast<bool>(x) == false ? false : *x != *y;
-}
-// Returns: If !y, false; otherwise, if !x, true; otherwise *x < *y.
-template <class T>
-constexpr bool operator<(const optional<T>& x, const optional<T>& y) {
- return !y ? false : !x ? true : *x < *y;
-}
-// Returns: If !x, false; otherwise, if !y, true; otherwise *x > *y.
-template <class T>
-constexpr bool operator>(const optional<T>& x, const optional<T>& y) {
- return !x ? false : !y ? true : *x > *y;
-}
-// Returns: If !x, true; otherwise, if !y, false; otherwise *x <= *y.
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const optional<T>& y) {
- return !x ? true : !y ? false : *x <= *y;
-}
-// Returns: If !y, true; otherwise, if !x, false; otherwise *x >= *y.
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const optional<T>& y) {
- return !y ? true : !x ? false : *x >= *y;
-}
-
-// Comparison with nullopt [optional.nullops]
-// The C++17 (N4606) "Returns:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, nullopt_t) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator<=(nullopt_t, const optional<T>& x) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator>(nullopt_t, const optional<T>& x) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, nullopt_t) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-
-// Comparison with T [optional.comp_with_t]
-// The C++17 (N4606) "Equivalent to:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x == v : false;
-}
-template <class T>
-constexpr bool operator==(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v == *x : false;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x != v : true;
-}
-template <class T>
-constexpr bool operator!=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v != *x : true;
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x < v : true;
-}
-template <class T>
-constexpr bool operator<(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v < *x : false;
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x <= v : true;
-}
-template <class T>
-constexpr bool operator<=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v <= *x : false;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x > v : false;
-}
-template <class T>
-constexpr bool operator>(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v > *x : true;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x >= v : false;
-}
-template <class T>
-constexpr bool operator>=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v >= *x : true;
-}
+using optional = absl::optional<T>;
} // namespace gtl
} // namespace tensorflow
-namespace std {
-
-// Normally std::hash specializations are not recommended in tensorflow code,
-// but we allow this as it is following a standard library component.
-template <class T>
-struct hash<::tensorflow::gtl::optional<T>> {
- size_t operator()(const ::tensorflow::gtl::optional<T>& opt) const {
- if (opt) {
- return hash<T>()(*opt);
- } else {
- return static_cast<size_t>(0x297814aaad196e6dULL);
- }
- }
-};
-
-} // namespace std
-
#endif // TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc
deleted file mode 100644
index 12b5bbc60b..0000000000
--- a/tensorflow/core/lib/gtl/optional_test.cc
+++ /dev/null
@@ -1,1098 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/optional.h"
-
-#include <string>
-#include <utility>
-
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-using tensorflow::gtl::in_place;
-using tensorflow::gtl::in_place_t;
-using tensorflow::gtl::make_optional;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::nullopt_t;
-using tensorflow::gtl::optional;
-
-template <typename T>
-string TypeQuals(T&) {
- return "&";
-}
-template <typename T>
-string TypeQuals(T&&) {
- return "&&";
-}
-template <typename T>
-string TypeQuals(const T&) {
- return "c&";
-}
-template <typename T>
-string TypeQuals(const T&&) {
- return "c&&";
-}
-
-struct StructorListener {
- int construct0 = 0;
- int construct1 = 0;
- int construct2 = 0;
- int listinit = 0;
- int copy = 0;
- int move = 0;
- int copy_assign = 0;
- int move_assign = 0;
- int destruct = 0;
-};
-
-struct Listenable {
- static StructorListener* listener;
-
- Listenable() { ++listener->construct0; }
- Listenable(int /*unused*/) { ++listener->construct1; } // NOLINT
- Listenable(int /*unused*/, int /*unused*/) { ++listener->construct2; }
- Listenable(std::initializer_list<int> /*unused*/) { ++listener->listinit; }
- Listenable(const Listenable& /*unused*/) { ++listener->copy; }
- Listenable(Listenable&& /*unused*/) { ++listener->move; } // NOLINT
- Listenable& operator=(const Listenable& /*unused*/) {
- ++listener->copy_assign;
- return *this;
- }
- Listenable& operator=(Listenable&& /*unused*/) { // NOLINT
- ++listener->move_assign;
- return *this;
- }
- ~Listenable() { ++listener->destruct; }
-};
-
-StructorListener* Listenable::listener = nullptr;
-
-// clang on macos -- even the latest major version at time of writing (8.x) --
-// does not like much of our constexpr business. clang < 3.0 also has trouble.
-#if defined(__clang__) && defined(__APPLE__)
-#define SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-#endif
-
-struct ConstexprType {
- constexpr ConstexprType() : x(0) {}
- constexpr explicit ConstexprType(int i) : x(i) {}
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr ConstexprType(std::initializer_list<int> il) : x(il.size()) {}
-#endif
- constexpr ConstexprType(const char* s) : x(-1) {} // NOLINT
- int x;
-};
-
-struct Copyable {
- Copyable() {}
- Copyable(const Copyable&) {}
- Copyable& operator=(const Copyable&) { return *this; }
-};
-
-struct MoveableThrow {
- MoveableThrow() {}
- MoveableThrow(MoveableThrow&&) {}
- MoveableThrow& operator=(MoveableThrow&&) { return *this; }
-};
-
-struct MoveableNoThrow {
- MoveableNoThrow() {}
- MoveableNoThrow(MoveableNoThrow&&) noexcept {}
- MoveableNoThrow& operator=(MoveableNoThrow&&) noexcept { return *this; }
-};
-
-struct NonMovable {
- NonMovable() {}
- NonMovable(const NonMovable&) = delete;
- NonMovable& operator=(const NonMovable&) = delete;
- NonMovable(NonMovable&&) = delete;
- NonMovable& operator=(NonMovable&&) = delete;
-};
-
-TEST(optionalTest, DefaultConstructor) {
- optional<int> empty;
- EXPECT_FALSE(!!empty);
- constexpr optional<int> cempty;
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE(std::is_nothrow_default_constructible<optional<int>>::value);
-}
-
-TEST(optionalTest, NullOptConstructor) {
- optional<int> empty(nullopt);
- EXPECT_FALSE(!!empty);
- // Creating a temporary nullopt_t object instead of using nullopt because
- // nullopt cannot be constexpr and have external linkage at the same time.
- constexpr optional<int> cempty{nullopt_t(nullopt_t::init)};
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE((std::is_nothrow_constructible<optional<int>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_copy(empty);
- EXPECT_FALSE(!!empty_copy);
- optional<int> opt42_copy(opt42);
- EXPECT_TRUE(!!opt42_copy);
- EXPECT_EQ(42, opt42_copy);
- // test copyablility
- EXPECT_TRUE(std::is_copy_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_copy_constructible<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_move(std::move(empty));
- EXPECT_FALSE(!!empty_move);
- optional<int> opt42_move(std::move(opt42));
- EXPECT_TRUE(!!opt42_move);
- EXPECT_EQ(42, opt42_move);
- // test movability
- EXPECT_TRUE(std::is_move_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_constructible<optional<NonMovable>>::value);
- // test noexcept
- EXPECT_TRUE(std::is_nothrow_move_constructible<optional<int>>::value);
- EXPECT_FALSE(
- std::is_nothrow_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_constructible<optional<MoveableNoThrow>>::value);
-}
-
-TEST(optionalTest, Destructor) {
- struct Trivial {};
-
- struct NonTrivial {
- ~NonTrivial() {}
- };
-
- EXPECT_TRUE(std::is_trivially_destructible<optional<int>>::value);
- EXPECT_TRUE(std::is_trivially_destructible<optional<Trivial>>::value);
- EXPECT_FALSE(std::is_trivially_destructible<optional<NonTrivial>>::value);
-}
-
-TEST(optionalTest, InPlaceConstructor) {
- constexpr optional<ConstexprType> opt0{in_place_t()};
- static_assert(opt0, "");
- static_assert(opt0->x == 0, "");
- constexpr optional<ConstexprType> opt1{in_place_t(), 1};
- static_assert(opt1, "");
- static_assert(opt1->x == 1, "");
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<ConstexprType> opt2{in_place_t(), {1, 2}};
- static_assert(opt2, "");
- static_assert(opt2->x == 2, "");
-#endif
-
- // TODO(b/34201852): uncomment these when std::is_constructible<T, Args&&...>
- // SFINAE is added to optional::optional(in_place_t, Args&&...).
- // struct I {
- // I(in_place_t);
- // };
-
- // EXPECT_FALSE((std::is_constructible<optional<I>, in_place_t>::value));
- // EXPECT_FALSE((std::is_constructible<optional<I>, const
- // in_place_t&>::value));
-}
-
-// template<U=T> optional(U&&);
-TEST(optionalTest, ValueConstructor) {
- constexpr optional<int> opt0(0);
- static_assert(opt0, "");
- static_assert(*opt0 == 0, "");
- EXPECT_TRUE((std::is_convertible<int, optional<int>>::value));
- // Copy initialization ( = "abc") won't work due to optional(optional&&)
- // is not constexpr. Use list initialization instead. This invokes
- // optional<ConstexprType>::optional<U>(U&&), with U = const char (&) [4],
- // which direct-initializes the ConstexprType value held by the optional
- // via ConstexprType::ConstexprType(const char*).
- constexpr optional<ConstexprType> opt1 = {"abc"};
- static_assert(opt1, "");
- static_assert(-1 == opt1->x, "");
- EXPECT_TRUE(
- (std::is_convertible<const char*, optional<ConstexprType>>::value));
- // direct initialization
- constexpr optional<ConstexprType> opt2{2};
- static_assert(opt2, "");
- static_assert(2 == opt2->x, "");
- EXPECT_FALSE((std::is_convertible<int, optional<ConstexprType>>::value));
-
- // this invokes optional<int>::optional(int&&)
- // NOTE: this has different behavior than assignment, e.g.
- // "opt3 = {};" clears the optional rather than setting the value to 0
- constexpr optional<int> opt3({});
- static_assert(opt3, "");
- static_assert(*opt3 == 0, "");
-
- // this invokes the move constructor with a default constructed optional
- // because non-template function is a better match than template function.
- optional<ConstexprType> opt4({});
- EXPECT_FALSE(!!opt4);
-}
-
-struct Implicit {};
-
-struct Explicit {};
-
-struct Convert {
- Convert(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false) {}
- Convert(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true) {}
- explicit Convert(const Explicit&) : implicit(false), move(false) {}
- explicit Convert(Explicit&&) : implicit(false), move(true) {}
-
- bool implicit;
- bool move;
-};
-
-struct ConvertFromOptional {
- ConvertFromOptional(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(false) {}
- ConvertFromOptional(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(false) {}
- ConvertFromOptional(const optional<Implicit>&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(true) {}
- ConvertFromOptional(optional<Implicit>&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(true) {}
- explicit ConvertFromOptional(const Explicit&)
- : implicit(false), move(false), from_optional(false) {}
- explicit ConvertFromOptional(Explicit&&)
- : implicit(false), move(true), from_optional(false) {}
- explicit ConvertFromOptional(const optional<Explicit>&)
- : implicit(false), move(false), from_optional(true) {}
- explicit ConvertFromOptional(optional<Explicit>&&)
- : implicit(false), move(true), from_optional(true) {}
-
- bool implicit;
- bool move;
- bool from_optional;
-};
-
-TEST(optionalTest, ConvertingConstructor) {
- optional<Implicit> i_empty;
- optional<Implicit> i(in_place);
- optional<Explicit> e_empty;
- optional<Explicit> e(in_place);
- {
- // implicitly constructing optional<Convert> from optional<Implicit>
- optional<Convert> empty = i_empty;
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy = i;
- EXPECT_TRUE(!!opt_copy);
- EXPECT_TRUE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- optional<Convert> opt_move = optional<Implicit>(in_place);
- EXPECT_TRUE(!!opt_move);
- EXPECT_TRUE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- }
- {
- // explicitly constructing optional<Convert> from optional<Explicit>
- optional<Convert> empty(e_empty);
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy(e);
- EXPECT_TRUE(!!opt_copy);
- EXPECT_FALSE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<Convert>>::value));
- optional<Convert> opt_move{optional<Explicit>(in_place)};
- EXPECT_TRUE(!!opt_move);
- EXPECT_FALSE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- EXPECT_FALSE(
- (std::is_convertible<optional<Explicit>&&, optional<Convert>>::value));
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Implicit> via ConvertFromOptional(optional<Implicit>&&)
- // check that ConvertFromOptional(Implicit&&) is NOT called
- static_assert(
- gtl::internal_optional::is_constructible_convertible_from_optional<
- ConvertFromOptional, Implicit>::value,
- "");
- optional<ConvertFromOptional> opt0 = i_empty;
- EXPECT_TRUE(!!opt0);
- EXPECT_TRUE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- optional<ConvertFromOptional> opt1 = optional<Implicit>();
- EXPECT_TRUE(!!opt1);
- EXPECT_TRUE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Explicit> via ConvertFromOptional(optional<Explicit>&&)
- // check that ConvertFromOptional(Explicit&&) is NOT called
- optional<ConvertFromOptional> opt0(e_empty);
- EXPECT_TRUE(!!opt0);
- EXPECT_FALSE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<ConvertFromOptional>>::value));
- optional<ConvertFromOptional> opt1{optional<Explicit>()};
- EXPECT_TRUE(!!opt1);
- EXPECT_FALSE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- EXPECT_FALSE((std::is_convertible<optional<Explicit>&&,
- optional<ConvertFromOptional>>::value));
- }
-}
-
-TEST(optionalTest, StructorBasic) {
- StructorListener listener;
- Listenable::listener = &listener;
- {
- optional<Listenable> empty;
- EXPECT_FALSE(!!empty);
- optional<Listenable> opt0(in_place);
- EXPECT_TRUE(!!opt0);
- optional<Listenable> opt1(in_place, 1);
- EXPECT_TRUE(!!opt1);
- optional<Listenable> opt2(in_place, 1, 2);
- EXPECT_TRUE(!!opt2);
- }
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(3, listener.destruct);
-}
-
-TEST(optionalTest, CopyMoveStructor) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> original(in_place);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> copy(original);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> move(std::move(original));
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(1, listener.move);
-}
-
-TEST(optionalTest, ListInit) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> listinit1(in_place, {1});
- optional<Listenable> listinit2(in_place, {1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, AssignFromNullopt) {
- optional<int> opt(1);
- opt = nullopt;
- EXPECT_FALSE(!!opt);
-
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt1(in_place);
- opt1 = nullopt;
- EXPECT_FALSE(opt1);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.destruct);
-
- EXPECT_TRUE((std::is_nothrow_assignable<optional<int>, nullopt_t>::value));
- EXPECT_TRUE(
- (std::is_nothrow_assignable<optional<Listenable>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyAssignment) {
- const optional<int> empty, opt1 = 1, opt2 = 2;
- optional<int> empty_to_opt1, opt1_to_opt2, opt2_to_empty;
-
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = empty;
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = opt1;
- EXPECT_TRUE(!!empty_to_opt1);
- EXPECT_EQ(1, empty_to_opt1.value());
-
- EXPECT_FALSE(!!opt1_to_opt2);
- opt1_to_opt2 = opt1;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(1, opt1_to_opt2.value());
- opt1_to_opt2 = opt2;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(2, opt1_to_opt2.value());
-
- EXPECT_FALSE(!!opt2_to_empty);
- opt2_to_empty = opt2;
- EXPECT_TRUE(!!opt2_to_empty);
- EXPECT_EQ(2, opt2_to_empty.value());
- opt2_to_empty = empty;
- EXPECT_FALSE(!!opt2_to_empty);
-
- EXPECT_TRUE(std::is_copy_assignable<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveAssignment) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> empty1, empty2, set1(in_place), set2(in_place);
- EXPECT_EQ(2, listener.construct0);
- optional<Listenable> empty_to_empty, empty_to_set, set_to_empty(in_place),
- set_to_set(in_place);
- EXPECT_EQ(4, listener.construct0);
- empty_to_empty = std::move(empty1);
- empty_to_set = std::move(set1);
- set_to_empty = std::move(empty2);
- set_to_set = std::move(set2);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(1, listener.destruct);
- EXPECT_EQ(1, listener.move_assign);
-
- EXPECT_TRUE(std::is_move_assignable<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_assignable<optional<NonMovable>>::value);
-
- EXPECT_FALSE(std::is_nothrow_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_assignable<optional<MoveableNoThrow>>::value);
-}
-
-struct NoConvertToOptional {
- // disable implicit conversion from const NoConvertToOptional&
- // to optional<NoConvertToOptional>.
- NoConvertToOptional(const NoConvertToOptional&) = delete;
-};
-
-struct CopyConvert {
- CopyConvert(const NoConvertToOptional&);
- CopyConvert& operator=(const CopyConvert&) = delete;
- CopyConvert& operator=(const NoConvertToOptional&);
-};
-
-struct CopyConvertFromOptional {
- CopyConvertFromOptional(const NoConvertToOptional&);
- CopyConvertFromOptional(const optional<NoConvertToOptional>&);
- CopyConvertFromOptional& operator=(const CopyConvertFromOptional&) = delete;
- CopyConvertFromOptional& operator=(const NoConvertToOptional&);
- CopyConvertFromOptional& operator=(const optional<NoConvertToOptional>&);
-};
-
-struct MoveConvert {
- MoveConvert(NoConvertToOptional&&);
- MoveConvert& operator=(const MoveConvert&) = delete;
- MoveConvert& operator=(NoConvertToOptional&&);
-};
-
-struct MoveConvertFromOptional {
- MoveConvertFromOptional(NoConvertToOptional&&);
- MoveConvertFromOptional(optional<NoConvertToOptional>&&);
- MoveConvertFromOptional& operator=(const MoveConvertFromOptional&) = delete;
- MoveConvertFromOptional& operator=(NoConvertToOptional&&);
- MoveConvertFromOptional& operator=(optional<NoConvertToOptional>&&);
-};
-
-// template <class U = T> optional<T>& operator=(U&& v);
-TEST(optionalTest, ValueAssignment) {
- optional<int> opt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = nullopt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = 43;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(43, opt.value());
- opt = {}; // this should clear optional
- EXPECT_FALSE(!!opt);
-
- opt = {44};
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(44, opt.value());
-
- // U = const NoConvertToOptional&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvert>&,
- const NoConvertToOptional&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvert>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- NoConvertToOptional&&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvertFromOptional, const
- // NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- NoConvertToOptional&&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
- // U = optional<NoConvertToOptional>
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- optional<NoConvertToOptional>&&>::value));
-}
-
-// template <class U> optional<T>& operator=(const optional<U>& rhs);
-// template <class U> optional<T>& operator=(optional<U>&& rhs);
-TEST(optionalTest, ConvertingAssignment) {
- optional<int> opt_i;
- optional<char> opt_c('c');
- opt_i = opt_c;
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ(*opt_c, *opt_i);
- opt_i = optional<char>();
- EXPECT_FALSE(!!opt_i);
- opt_i = optional<char>('d');
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ('d', *opt_i);
-
- optional<string> opt_str;
- optional<const char*> opt_cstr("abc");
- opt_str = opt_cstr;
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("abc"), *opt_str);
- opt_str = optional<const char*>();
- EXPECT_FALSE(!!opt_str);
- opt_str = optional<const char*>("def");
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("def"), *opt_str);
-
- // operator=(const optional<U>&) with U = NoConvertToOptional
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvert>,
- const optional<NoConvertToOptional>&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional
- // triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvert>&,
- const optional<NoConvertToOptional>&>::value));
- // operator=(optional<U>&&) with U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- optional<NoConvertToOptional>&&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional triggers SFINAE
- // because std::is_constructible_v<MoveConvertFromOptional,
- // const NoConvertToOptional&> is false.
- // operator=(U&&) with U = const optional<NoConverToOptional>& triggers SFINAE
- // because std::is_constructible<MoveConvertFromOptional,
- // optional<NoConvertToOptional>&&> is true.
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
-}
-
-TEST(optionalTest, ResetAndHasValue) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- opt.emplace();
- EXPECT_TRUE(!!opt);
- EXPECT_TRUE(opt.has_value());
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- EXPECT_EQ(1, listener.destruct);
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
-
- constexpr optional<int> empty;
- static_assert(!empty.has_value(), "");
- constexpr optional<int> nonempty(1);
- static_assert(nonempty.has_value(), "");
-}
-
-TEST(optionalTest, Emplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace(1);
- EXPECT_TRUE(!!opt);
- opt.emplace(1, 2);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, ListEmplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace({1});
- EXPECT_TRUE(!!opt);
- opt.emplace({1, 2});
- EXPECT_EQ(2, listener.listinit);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, Swap) {
- optional<int> opt_empty, opt1 = 1, opt2 = 2;
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt1);
- EXPECT_TRUE(!!opt_empty);
- EXPECT_EQ(1, opt_empty.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt1, opt2);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(2, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(1, opt2.value());
-
- EXPECT_TRUE(noexcept(opt1.swap(opt2)));
- EXPECT_TRUE(noexcept(swap(opt1, opt2)));
-}
-
-TEST(optionalTest, PointerStuff) {
- optional<string> opt(in_place, "foo");
- EXPECT_EQ("foo", *opt);
- const auto& opt_const = opt;
- EXPECT_EQ("foo", *opt_const);
- EXPECT_EQ(opt->size(), 3);
- EXPECT_EQ(opt_const->size(), 3);
-
- constexpr optional<ConstexprType> opt1(1);
- static_assert(opt1->x == 1, "");
-}
-
-// gcc has a bug pre 4.9 where it doesn't do correct overload resolution
-// between rvalue reference qualified member methods. Skip that test to make
-// the build green again when using the old compiler.
-#if defined(__GNUC__) && !defined(__clang__)
-#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 9)
-#define SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
-#endif
-#endif
-
-TEST(optionalTest, Value) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", lvalue.value());
- EXPECT_EQ("clvalue", clvalue.value());
- EXPECT_EQ("xvalue", O(in_place, "xvalue").value());
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", CO(in_place, "cxvalue").value());
- EXPECT_EQ("&", TypeQuals(lvalue.value()));
- EXPECT_EQ("c&", TypeQuals(clvalue.value()));
- EXPECT_EQ("&&", TypeQuals(O(in_place, "xvalue").value()));
- EXPECT_EQ("c&&", TypeQuals(CO(in_place, "cxvalue").value()));
-#endif
-}
-
-TEST(optionalTest, DerefOperator) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", *lvalue);
- EXPECT_EQ("clvalue", *clvalue);
- EXPECT_EQ("xvalue", *O(in_place, "xvalue"));
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", *CO(in_place, "cxvalue"));
- EXPECT_EQ("&", TypeQuals(*lvalue));
- EXPECT_EQ("c&", TypeQuals(*clvalue));
- EXPECT_EQ("&&", TypeQuals(*O(in_place, "xvalue")));
- EXPECT_EQ("c&&", TypeQuals(*CO(in_place, "cxvalue")));
-#endif
-
- constexpr optional<int> opt1(1);
- static_assert(*opt1 == 1, "");
-
-#if !defined(SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG) && \
- !defined(SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG)
- using COI = const optional<int>;
- static_assert(*COI(2) == 2, "");
-#endif
-}
-
-TEST(optionalTest, ValueOr) {
- optional<double> opt_empty, opt_set = 1.2;
- EXPECT_EQ(42.0, opt_empty.value_or(42));
- EXPECT_EQ(1.2, opt_set.value_or(42));
- EXPECT_EQ(42.0, optional<double>().value_or(42));
- EXPECT_EQ(1.2, optional<double>(1.2).value_or(42));
-
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<double> copt_empty;
- static_assert(42.0 == copt_empty.value_or(42), "");
-
- constexpr optional<double> copt_set = {1.2};
- static_assert(1.2 == copt_set.value_or(42), "");
-
- using COD = const optional<double>;
- static_assert(42.0 == COD().value_or(42), "");
- static_assert(1.2 == COD(1.2).value_or(42), "");
-#endif
-}
-
-// make_optional cannot be constexpr until C++17
-TEST(optionalTest, make_optional) {
- auto opt_int = make_optional(42);
- EXPECT_TRUE((std::is_same<decltype(opt_int), optional<int>>::value));
- EXPECT_EQ(42, opt_int);
-
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> opt0 = make_optional<Listenable>();
- EXPECT_EQ(1, listener.construct0);
- optional<Listenable> opt1 = make_optional<Listenable>(1);
- EXPECT_EQ(1, listener.construct1);
- optional<Listenable> opt2 = make_optional<Listenable>(1, 2);
- EXPECT_EQ(1, listener.construct2);
- optional<Listenable> opt3 = make_optional<Listenable>({1});
- optional<Listenable> opt4 = make_optional<Listenable>({1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, Comparisons) {
- optional<int> ae, be, a2 = 2, b2 = 2, a4 = 4, b4 = 4;
-
-#define optionalTest_Comparisons_EXPECT_LESS(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_TRUE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_FALSE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_SAME(x, y) \
- EXPECT_TRUE((x) == (y)); \
- EXPECT_FALSE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_GREATER(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_TRUE((x) > (y)); \
- EXPECT_FALSE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
- // LHS: nullopt, ae, a2, 3, a4
- // RHS: nullopt, be, b2, 3, b4
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,nullopt);
- optionalTest_Comparisons_EXPECT_SAME(nullopt, be);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b2);
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,3);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b4);
-
- optionalTest_Comparisons_EXPECT_SAME(ae, nullopt);
- optionalTest_Comparisons_EXPECT_SAME(ae, be);
- optionalTest_Comparisons_EXPECT_LESS(ae, b2);
- optionalTest_Comparisons_EXPECT_LESS(ae, 3);
- optionalTest_Comparisons_EXPECT_LESS(ae, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a2, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a2, be);
- optionalTest_Comparisons_EXPECT_SAME(a2, b2);
- optionalTest_Comparisons_EXPECT_LESS(a2, 3);
- optionalTest_Comparisons_EXPECT_LESS(a2, b4);
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(3,nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(3, be);
- optionalTest_Comparisons_EXPECT_GREATER(3, b2);
- optionalTest_Comparisons_EXPECT_SAME(3, 3);
- optionalTest_Comparisons_EXPECT_LESS(3, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a4, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a4, be);
- optionalTest_Comparisons_EXPECT_GREATER(a4, b2);
- optionalTest_Comparisons_EXPECT_GREATER(a4, 3);
- optionalTest_Comparisons_EXPECT_SAME(a4, b4);
-}
-
-TEST(optionalTest, SwapRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- optional<Listenable> b(in_place);
- a.swap(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-
- {
- optional<Listenable> a(in_place);
- optional<Listenable> b;
- a.swap(b);
- }
-
- EXPECT_EQ(2, listener.construct0);
- EXPECT_EQ(2, listener.move);
- EXPECT_EQ(4, listener.destruct);
-}
-
-TEST(optionalTest, BigStringLeakCheck) {
- constexpr size_t n = 1 << 16;
-
- using OS = optional<string>;
-
- OS a;
- OS b = nullopt;
- OS c = string(n, 'c');
- string sd(n, 'd');
- OS d = sd;
- OS e(in_place, n, 'e');
- OS f;
- f.emplace(n, 'f');
-
- OS ca(a);
- OS cb(b);
- OS cc(c);
- OS cd(d);
- OS ce(e);
-
- OS oa;
- OS ob = nullopt;
- OS oc = string(n, 'c');
- string sod(n, 'd');
- OS od = sod;
- OS oe(in_place, n, 'e');
- OS of;
- of.emplace(n, 'f');
-
- OS ma(std::move(oa));
- OS mb(std::move(ob));
- OS mc(std::move(oc));
- OS md(std::move(od));
- OS me(std::move(oe));
- OS mf(std::move(of));
-
- OS aa1;
- OS ab1 = nullopt;
- OS ac1 = string(n, 'c');
- string sad1(n, 'd');
- OS ad1 = sad1;
- OS ae1(in_place, n, 'e');
- OS af1;
- af1.emplace(n, 'f');
-
- OS aa2;
- OS ab2 = nullopt;
- OS ac2 = string(n, 'c');
- string sad2(n, 'd');
- OS ad2 = sad2;
- OS ae2(in_place, n, 'e');
- OS af2;
- af2.emplace(n, 'f');
-
- aa1 = af2;
- ab1 = ae2;
- ac1 = ad2;
- ad1 = ac2;
- ae1 = ab2;
- af1 = aa2;
-
- OS aa3;
- OS ab3 = nullopt;
- OS ac3 = string(n, 'c');
- string sad3(n, 'd');
- OS ad3 = sad3;
- OS ae3(in_place, n, 'e');
- OS af3;
- af3.emplace(n, 'f');
-
- aa3 = nullopt;
- ab3 = nullopt;
- ac3 = nullopt;
- ad3 = nullopt;
- ae3 = nullopt;
- af3 = nullopt;
-
- OS aa4;
- OS ab4 = nullopt;
- OS ac4 = string(n, 'c');
- string sad4(n, 'd');
- OS ad4 = sad4;
- OS ae4(in_place, n, 'e');
- OS af4;
- af4.emplace(n, 'f');
-
- aa4 = OS(in_place, n, 'a');
- ab4 = OS(in_place, n, 'b');
- ac4 = OS(in_place, n, 'c');
- ad4 = OS(in_place, n, 'd');
- ae4 = OS(in_place, n, 'e');
- af4 = OS(in_place, n, 'f');
-
- OS aa5;
- OS ab5 = nullopt;
- OS ac5 = string(n, 'c');
- string sad5(n, 'd');
- OS ad5 = sad5;
- OS ae5(in_place, n, 'e');
- OS af5;
- af5.emplace(n, 'f');
-
- string saa5(n, 'a');
- string sab5(n, 'a');
- string sac5(n, 'a');
- string sad52(n, 'a');
- string sae5(n, 'a');
- string saf5(n, 'a');
-
- aa5 = saa5;
- ab5 = sab5;
- ac5 = sac5;
- ad5 = sad52;
- ae5 = sae5;
- af5 = saf5;
-
- OS aa6;
- OS ab6 = nullopt;
- OS ac6 = string(n, 'c');
- string sad6(n, 'd');
- OS ad6 = sad6;
- OS ae6(in_place, n, 'e');
- OS af6;
- af6.emplace(n, 'f');
-
- aa6 = string(n, 'a');
- ab6 = string(n, 'b');
- ac6 = string(n, 'c');
- ad6 = string(n, 'd');
- ae6 = string(n, 'e');
- af6 = string(n, 'f');
-
- OS aa7;
- OS ab7 = nullopt;
- OS ac7 = string(n, 'c');
- string sad7(n, 'd');
- OS ad7 = sad7;
- OS ae7(in_place, n, 'e');
- OS af7;
- af7.emplace(n, 'f');
-
- aa7.emplace(n, 'A');
- ab7.emplace(n, 'B');
- ac7.emplace(n, 'C');
- ad7.emplace(n, 'D');
- ae7.emplace(n, 'E');
- af7.emplace(n, 'F');
-}
-
-TEST(optionalTest, MoveAssignRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- Listenable b;
- a = std::move(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-}
-
-TEST(optionalTest, ValueType) {
- EXPECT_TRUE((std::is_same<optional<int>::value_type, int>::value));
- EXPECT_TRUE((std::is_same<optional<string>::value_type, string>::value));
- EXPECT_FALSE((std::is_same<optional<int>::value_type, nullopt_t>::value));
-}
-
-TEST(optionalTest, Hash) {
- std::hash<optional<int>> hash;
- std::set<size_t> hashcodes;
- hashcodes.insert(hash(nullopt));
- for (int i = 0; i < 100; ++i) {
- hashcodes.insert(hash(i));
- }
- EXPECT_GT(hashcodes.size(), 90);
-}
-
-struct MoveMeNoThrow {
- MoveMeNoThrow() : x(0) {}
- MoveMeNoThrow(const MoveMeNoThrow& other) : x(other.x) {
- LOG(FATAL) << "Should not be called.";
- }
- MoveMeNoThrow(MoveMeNoThrow&& other) noexcept : x(other.x) {}
- int x;
-};
-
-struct MoveMeThrow {
- MoveMeThrow() : x(0) {}
- MoveMeThrow(const MoveMeThrow& other) : x(other.x) {}
- MoveMeThrow(MoveMeThrow&& other) : x(other.x) {}
- int x;
-};
-
-TEST(optionalTest, NoExcept) {
- static_assert(
- std::is_nothrow_move_constructible<optional<MoveMeNoThrow>>::value, "");
- static_assert(
- !std::is_nothrow_move_constructible<optional<MoveMeThrow>>::value, "");
- std::vector<optional<MoveMeNoThrow>> v;
- v.reserve(10);
- for (int i = 0; i < 10; ++i) v.emplace_back();
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 5ae3d220e3..a620f59447 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -59,29 +59,29 @@ namespace tensorflow {
namespace strings {
enum PadSpec {
- NO_PAD = 1,
- ZERO_PAD_2,
- ZERO_PAD_3,
- ZERO_PAD_4,
- ZERO_PAD_5,
- ZERO_PAD_6,
- ZERO_PAD_7,
- ZERO_PAD_8,
- ZERO_PAD_9,
- ZERO_PAD_10,
- ZERO_PAD_11,
- ZERO_PAD_12,
- ZERO_PAD_13,
- ZERO_PAD_14,
- ZERO_PAD_15,
- ZERO_PAD_16,
+ kNoPad = 1,
+ kZeroPad2,
+ kZeroPad3,
+ kZeroPad4,
+ kZeroPad5,
+ kZeroPad6,
+ kZeroPad7,
+ kZeroPad8,
+ kZeroPad9,
+ kZeroPad10,
+ kZeroPad11,
+ kZeroPad12,
+ kZeroPad13,
+ kZeroPad14,
+ kZeroPad15,
+ kZeroPad16
};
struct Hex {
uint64 value;
enum PadSpec spec;
template <class Int>
- explicit Hex(Int v, PadSpec s = NO_PAD) : spec(s) {
+ explicit Hex(Int v, PadSpec s = kNoPad) : spec(s) {
// Prevent sign-extension by casting integers to
// their unsigned counterparts.
static_assert(
@@ -124,6 +124,9 @@ class AlphaNum {
AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
: piece_(str) {}
+ template <typename A>
+ AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str)
+ : piece_(str) {} // NOLINT(runtime/explicit)
StringPiece::size_type size() const { return piece_.size(); }
const char *data() const { return piece_.data(); }
diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc
index 8cc64a6f0a..6c4e5526b1 100644
--- a/tensorflow/core/lib/strings/strcat_test.cc
+++ b/tensorflow/core/lib/strings/strcat_test.cc
@@ -308,11 +308,11 @@ TEST(StrAppend, Death) {
static void CheckHex64(uint64 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad16));
string expected = Printf("%016llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
- actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
expected = Printf("%08llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -323,7 +323,7 @@ static void CheckHex64(uint64 v) {
static void CheckHex32(uint32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -334,7 +334,7 @@ static void CheckHex32(uint32 v) {
static void CheckHexSigned32(int32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 82e4831e00..cb0cb46752 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -22505,33 +22505,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -37424,6 +37397,201 @@ op {
}
}
op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -56497,6 +56665,125 @@ op {
}
}
op {
+ name: "SdcaOptimizer"
+ input_arg {
+ name: "sparse_example_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_values"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features_with_values"
+ }
+ input_arg {
+ name: "dense_features"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "example_labels"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "sparse_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_delta_sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ output_arg {
+ name: "out_delta_dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ attr {
+ name: "loss_type"
+ type: "string"
+ allowed_values {
+ list {
+ s: "logistic_loss"
+ s: "squared_loss"
+ s: "hinge_loss"
+ s: "smooth_hinge_loss"
+ s: "poisson_loss"
+ }
+ }
+ }
+ attr {
+ name: "adaptative"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "num_sparse_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_sparse_features_with_values"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_dense_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "l1"
+ type: "float"
+ }
+ attr {
+ name: "l2"
+ type: "float"
+ }
+ attr {
+ name: "num_loss_partitions"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "num_inner_iterations"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "SdcaShrinkL1"
input_arg {
name: "weights"
@@ -71867,6 +72154,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -71983,6 +72289,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 41f5f9aebe..f03639e833 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -182,18 +182,6 @@ REGISTER_OP("ParseExampleDataset")
// sparse_keys combined) here.
.SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("FeatureStatsDataset")
- .Input("input_dataset: variant")
- .Input("tag: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle tag_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
- return shape_inference::ScalarShape(c);
- });
-
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index b9f94ba1c5..7d79df9c1c 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor")
shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o));
shape_inference::ShapeHandle element_shape;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}});
@@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
c->set_output_handle_shapes_and_types(
@@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+REGISTER_OP("TensorListGather")
+ .Input("input_handle: variant")
+ .Input("indices: int32")
+ .Output("values: element_dtype")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ shape_inference::ShapeHandle element_shape = c->UnknownShape();
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ element_shape = list_shape_type.shape;
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument("Expected list with element dtype ",
+ DataTypeString(t),
+ " but got list with element dtype ",
+ DataTypeString(list_shape_type.dtype));
+ }
+ }
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListScatter")
+ .Input("tensor: element_dtype")
+ .Input("indices: int32")
+ .Input("element_shape: shape_type")
+ .Output("output_handle: variant")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s));
+ c->set_output_handle_shapes_and_types(0, {{s, t}});
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("TensorListConcatLists")
.Input("input_a: variant")
.Input("input_b: variant")
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index fbde692e95..639d211767 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -27,6 +28,8 @@ REGISTER_OP("Assert")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::NoOutputs);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Assert");
+
REGISTER_OP("Print")
.Input("input: T")
.Input("data: U")
@@ -39,6 +42,8 @@ REGISTER_OP("Print")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::UnchangedShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
@@ -116,4 +121,6 @@ REGISTER_OP("Timestamp")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Timestamp");
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 9429d91cb9..4419f93d0c 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10459,33 +10459,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -18475,6 +18448,201 @@ op {
}
}
op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -26809,6 +26977,7 @@ op {
s: "squared_loss"
s: "hinge_loss"
s: "smooth_hinge_loss"
+ s: "poisson_loss"
}
}
}
@@ -33981,6 +34150,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -34097,6 +34285,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index ddb714b4e9..79ca96d249 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -132,6 +132,99 @@ REGISTER_OP("ParseSingleExample")
return Status::OK();
});
+REGISTER_OP("ParseSequenceExample")
+ .Input("serialized: string")
+ .Input("debug_name: string")
+ .Input("context_dense_defaults: Tcontext_dense")
+ .Output("context_sparse_indices: Ncontext_sparse * int64")
+ .Output("context_sparse_values: context_sparse_types")
+ .Output("context_sparse_shapes: Ncontext_sparse * int64")
+ .Output("context_dense_values: Tcontext_dense")
+ .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64")
+ .Output("feature_list_sparse_values: feature_list_sparse_types")
+ .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64")
+ .Output("feature_list_dense_values: feature_list_dense_types")
+ .Output("feature_list_dense_lengths: Nfeature_list_dense * int64")
+ .Attr("feature_list_dense_missing_assumed_empty: list(string) >= 0")
+ .Attr("context_sparse_keys: list(string) >= 0")
+ .Attr("context_dense_keys: list(string) >= 0")
+ .Attr("feature_list_sparse_keys: list(string) >= 0")
+ .Attr("feature_list_dense_keys: list(string) >= 0")
+ .Attr("Ncontext_sparse: int >= 0 = 0")
+ .Attr("Ncontext_dense: int >= 0 = 0")
+ .Attr("Nfeature_list_sparse: int >= 0 = 0")
+ .Attr("Nfeature_list_dense: int >= 0 = 0")
+ .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []")
+ .Attr("context_dense_shapes: list(shape) >= 0 = []")
+ .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_shapes: list(shape) >= 0 = []")
+ .SetShapeFn([](InferenceContext* c) {
+ ParseSequenceExampleAttrs attrs;
+ TF_RETURN_IF_ERROR(attrs.Init(c));
+
+ // Verify that the input is a vector, and carry the shape if known.
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
+ shape_inference::DimensionHandle num_examples = c->Dim(input, 0);
+
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // debug_name
+
+ int output_idx = 0;
+
+ // Output context_sparse_indices, context_sparse_values, and
+ // context_sparse_shapes.
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 2));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(2));
+ }
+
+ // Output context_dense_values.
+ for (int i = 0; i < attrs.num_context_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.context_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(num_examples), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_sparse_indices, feature_list_sparse_values,
+ // feature_list_sparse_shapes.
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 3));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(3));
+ }
+
+ // Output feature_list_dense_shapes.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.feature_list_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->Matrix(num_examples, c->UnknownDim()), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_dense_lengths.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ c->set_output(output_idx++, c->Vector(num_examples));
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ParseSingleSequenceExample")
.Input("serialized: string")
.Input("feature_list_dense_missing_assumed_empty: string")
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index 9121d7ae92..c65e66d1a8 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -143,6 +143,88 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) {
"?;?;?;?;?;?;?;?;?;?");
}
+TEST(ParsingOpsTest, ParseSequenceExample_ShapeFn) {
+ ShapeInferenceTestOp op("ParseSequenceExample");
+ auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
+ int num_feature_list_sparse,
+ int num_feature_list_dense,
+ bool add_extra_shape = false) {
+ using NodeOutList = std::vector<NodeDefBuilder::NodeOut>;
+ using DataTypeList = std::vector<DataType>;
+ string string_in("test");
+ NodeDefBuilder::NodeOut node_in{"a", 0, DT_STRING};
+ TF_ASSERT_OK(
+ NodeDefBuilder("test", "ParseSequenceExample")
+ .Input("serialized", 0, DT_STRING)
+ .Input("debug_name", 0, DT_STRING)
+ .Input(NodeOutList(num_context_dense, node_in))
+ .Attr("Ncontext_sparse", num_context_sparse)
+ .Attr("Ncontext_dense", num_context_dense)
+ .Attr("Nfeature_list_sparse", num_feature_list_sparse)
+ .Attr("Nfeature_list_dense", num_feature_list_dense)
+ .Attr("feature_list_dense_missing_assumed_empty",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_keys",
+ std::vector<string>(num_context_sparse, string_in))
+ .Attr("context_dense_keys",
+ std::vector<string>(num_context_dense, string_in))
+ .Attr("feature_list_sparse_keys",
+ std::vector<string>(num_feature_list_sparse, string_in))
+ .Attr("feature_list_dense_keys",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_types",
+ DataTypeList(num_context_sparse, DT_FLOAT))
+ .Attr("context_dense_types",
+ DataTypeList(num_context_dense, DT_FLOAT))
+ .Attr("context_dense_shapes",
+ MakeDenseShapes(num_context_dense, add_extra_shape, 0))
+ .Attr("feature_list_sparse_types",
+ DataTypeList(num_feature_list_sparse, DT_FLOAT))
+ .Attr("feature_list_dense_types",
+ DataTypeList(num_feature_list_dense, DT_FLOAT))
+ .Attr("feature_list_dense_shapes",
+ MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0))
+ .Finalize(&op.node_def));
+ };
+
+ // Verify inputs 'serialized' and 'debug_name'.
+ set_outputs(0, 0, 0, 0);
+ INFER_OK(op, "[?];[?]", "");
+ INFER_OK(op, "[8];[8]", "");
+ INFER_ERROR("must be rank 1", op, "[];[?]");
+ INFER_ERROR("must be rank 1", op, "[?];[]");
+
+ // context inputs with no feature_list inputs.
+ set_outputs(2 /* num_context_sparse */, 3 /* num_context_dense */, 0, 0);
+ INFER_OK(op, "[?];[?];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3]")); // context dense
+
+ // feature_list inputs with no context inputs.
+ set_outputs(0, 0, 2 /* num_feature_list_sparse */,
+ 3 /* num_feature_list_dense */);
+ INFER_OK(op, "[?];[?]",
+ ("[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Combine previous two test cases.
+ set_outputs(2, 3, 2, 3);
+ INFER_OK(op, "[7];[7];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3];" // context dense
+ "[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Confirm an error from ParseSequenceExampleAttrs.Init().
+ set_outputs(1, 1, 1, 1, true /* add_extra_shape */);
+ INFER_ERROR(
+ "num_context_dense (1) must match the size of context_dense_keys (1), "
+ "context_dense_types (1) and context_dense_shapes (2)",
+ op, "[?];[?];?");
+}
+
TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) {
ShapeInferenceTestOp op("ParseSingleSequenceExample");
auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 4025070adb..fdf53a55dd 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) {
REGISTER_OP("SdcaOptimizer")
.Attr(
"loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
- "'smooth_hinge_loss'}")
+ "'smooth_hinge_loss', 'poisson_loss'}")
.Attr("adaptative : bool=false")
.Attr("num_sparse_features: int >= 0")
.Attr("num_sparse_features_with_values: int >= 0")
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6a4ff9a1cb..07b2e3426b 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -623,7 +623,11 @@ def tf_additional_lib_defines():
def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
- return ["@com_google_absl//absl/base:base"] + if_static(
+ return [
+ "@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:optional",
+ ] + if_static(
["@nsync//:nsync_cpp"],
["@nsync//:nsync_headers"],
) + select({
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
index 811cf406b9..8ca76c44c0 100644
--- a/tensorflow/core/protobuf/debug.proto
+++ b/tensorflow/core/protobuf/debug.proto
@@ -60,6 +60,12 @@ message DebugOptions {
// Note that this is distinct from the session run count and the executor
// step count.
int64 global_step = 10;
+
+ // Whether the total disk usage of tfdbg is to be reset to zero
+ // in this Session.run call. This is used by wrappers and hooks
+ // such as the local CLI ones to indicate that the dumped tensors
+ // are cleaned up from the disk after each Session.run.
+ bool reset_disk_byte_usage = 11;
}
message DebuggedSourceFile {
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index a38cd1d09f..e52d55e2ff 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -1722,10 +1722,11 @@ Status FastParseSequenceExample(
const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* context_result,
- Result* feature_list_result) {
+ Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size();
DCHECK(context_result != nullptr);
DCHECK(feature_list_result != nullptr);
+ DCHECK(dense_feature_lengths != nullptr);
std::map<StringPiece, bool> context_is_sparse;
std::map<StringPiece, std::pair<DataType, size_t>>
context_feature_type_and_lengths;
@@ -1740,9 +1741,22 @@ Status FastParseSequenceExample(
context_is_sparse[c.feature_name] = true;
}
for (auto& c : context_config.dense) {
+ if (context_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Context feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
context_feature_type_and_lengths[c.feature_name] =
- std::make_pair(c.dtype, 0);
+ std::make_pair(c.dtype, c.default_value.NumElements());
+ if (c.default_value.NumElements() > 0) {
+ if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
+ return errors::InvalidArgument("Default value for context feature ",
+ c.feature_name,
+ " has an incorrect shape: saw ",
+ c.default_value.shape().DebugString(),
+ " but expected ", c.shape.DebugString());
+ }
+ }
context_is_sparse[c.feature_name] = false;
}
std::map<StringPiece, bool> sequence_is_sparse;
@@ -1755,6 +1769,10 @@ Status FastParseSequenceExample(
sequence_is_sparse[c.feature_name] = true;
}
for (auto& c : feature_list_config.dense) {
+ if (sequence_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Sequence feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
sequence_feature_type_and_lengths[c.feature_name] =
std::make_pair(c.dtype, 0);
@@ -1792,14 +1810,14 @@ Status FastParseSequenceExample(
features = sequence_features;
config = &sequence_feature_type_and_lengths;
} else if (!SkipExtraneousTag(&stream)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
if (features != nullptr) {
uint32 length;
if (!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
while (!stream.ExpectAtEnd()) {
@@ -1807,16 +1825,16 @@ Status FastParseSequenceExample(
uint32 length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!ParseString(&stream, &key) ||
!stream.ExpectTag(kDelimitedTag(2)) ||
!ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
stream.PopLimit(limit);
// Only save if this feature was requested.
@@ -1851,9 +1869,8 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in context feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in context feature ", c.first,
+ " in example ", example_name);
}
num_elements += num;
}
@@ -1876,9 +1893,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -1898,22 +1915,22 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
num_elements += num;
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
}
}
@@ -1936,15 +1953,19 @@ Status FastParseSequenceExample(
feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ dense_feature_lengths->resize(feature_list_config.dense.size());
+
int t = 0;
for (const auto& c : context_config.dense) {
- TensorShape dense_shape;
+ TensorShape dense_shape, example_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
context_feature_type_and_lengths[c.feature_name].second;
- if (expected_max_elements != dense_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Inconsistent number of elements for feature ", c.feature_name));
+ if (!c.shape.AsTensorShape(&example_shape) ||
+ expected_max_elements != example_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Inconsistent number of elements for feature ", c.feature_name, ": ",
+ expected_max_elements, " vs ", dense_shape.num_elements());
}
dense_shape.AddDim(num_examples);
for (const int dim : c.shape.dim_sizes()) {
@@ -1968,18 +1989,58 @@ Status FastParseSequenceExample(
out_int64 = context_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0;
- const auto& feature = all_context_features[e][c.feature_name];
+ const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_context_features[e].end()) {
+ // Copy the default value, if present. If not, return an error.
+ if (c.default_value.NumElements() == 0) {
+ return errors::InvalidArgument(
+ "Feature: ", c.feature_name,
+ " (data type: ", DataTypeString(c.dtype), ")",
+ " is required but could not be found.");
+ }
+ const string* in_bytes = nullptr;
+ const float* in_float = nullptr;
+ const int64* in_int64 = nullptr;
+ size_t num = 0;
+ switch (dtype) {
+ case DT_STRING:
+ in_bytes = c.default_value.flat<string>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_bytes++ = *in_bytes++;
+ }
+ break;
+ case DT_FLOAT:
+ in_float = c.default_value.flat<float>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_float++ = *in_float++;
+ }
+ break;
+ case DT_INT64:
+ in_int64 = c.default_value.flat<int64>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_int64++ = *in_int64++;
+ }
+ break;
+ default:
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
+ }
+ num_elements += num;
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -1998,14 +2059,14 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
}
if (num_elements != expected_max_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in example ", example_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in example ", example_name);
}
}
}
@@ -2037,8 +2098,8 @@ Status FastParseSequenceExample(
out_int64 = context_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
auto out_shape = context_result->sparse_shapes[t].vec<int64>();
@@ -2070,8 +2131,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2082,30 +2143,35 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected total number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected total number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_cols;
}
t = 0;
+ TensorShape dense_length_shape({num_examples});
for (const auto& c : feature_list_config.dense) {
TensorShape dense_shape, row_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
sequence_feature_type_and_lengths[c.feature_name].second;
- int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
if (!c.shape.AsTensorShape(&row_shape) ||
- expected_max_elements != expected_max_rows * row_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected shape error in feature ", c.feature_name));
+ expected_max_elements !=
+ (expected_max_elements / row_shape.num_elements()) *
+ row_shape.num_elements()) {
+ return errors::InvalidArgument("Unexpected shape error in feature ",
+ c.feature_name);
}
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
dense_shape.AddDim(num_examples);
dense_shape.AddDim(expected_max_rows);
for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
dense_shape.AddDim(dim);
}
feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+ (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
+ int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr;
float* out_float = nullptr;
@@ -2121,18 +2187,26 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
- size_t num_elements = 0;
- const auto& feature = all_sequence_features[e][c.feature_name];
+ size_t num_elements = 0, num_rows = 0;
+ const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_sequence_features[e].end()) {
+ // Return an error if this feature was not allowed to be missing.
+ // Otherwise, we'll pad as needed below.
+ if (!c.variable_length) {
+ return errors::InvalidArgument("Missing feature ", c.feature_name,
+ " in example ", example_name);
+ }
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -2140,9 +2214,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
auto limit = stream.PushLimit(feature_length);
size_t num_added;
@@ -2160,10 +2234,11 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
+ num_rows++;
if (num_added != row_shape.num_elements()) {
return errors::InvalidArgument(
"Unexpected number of elements in feature ", c.feature_name,
@@ -2172,6 +2247,7 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
}
}
+ *out_lengths++ = num_rows;
// Pad as necessary.
int num_to_pad = expected_max_elements - num_elements;
switch (dtype) {
@@ -2187,8 +2263,8 @@ Status FastParseSequenceExample(
out_int64 += num_to_pad;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
}
}
@@ -2219,8 +2295,8 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices =
feature_list_result->sparse_indices[t].flat<int64>().data();
@@ -2244,9 +2320,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -2265,8 +2341,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2278,14 +2354,14 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
num_rows++;
}
@@ -2293,8 +2369,8 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_rows;
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index db5b5ff929..055d9c2c30 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -118,7 +118,8 @@ Status FastParseSequenceExample(
const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result,
- example::Result* feature_list_result);
+ example::Result* feature_list_result,
+ std::vector<Tensor>* dense_feature_lengths);
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
index 37faa927bf..6c5f80a535 100644
--- a/tensorflow/core/util/example_proto_fast_parsing_test.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -42,7 +42,7 @@ string SerializedToReadable(string serialized) {
string result;
result += '"';
for (char c : serialized)
- result += strings::StrCat("\\x", strings::Hex(c, strings::ZERO_PAD_2));
+ result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
result += '"';
return result;
}
diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc
index e156a3bc8f..41fb20c00a 100644
--- a/tensorflow/core/util/example_proto_helper.cc
+++ b/tensorflow/core/util/example_proto_helper.cc
@@ -443,6 +443,59 @@ Status ParseSingleExampleAttrs::FinishInit() {
return Status::OK();
}
+Status ParseSequenceExampleAttrs::FinishInit() {
+ if (num_context_sparse != context_sparse_keys.size() ||
+ num_context_sparse != context_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_context_sparse (", num_context_sparse,
+ ") must match the size of context_sparse_keys (",
+ context_sparse_keys.size(), ") and context_sparse_types (",
+ context_sparse_types.size(), ")");
+ }
+ if (num_context_dense != context_dense_keys.size() ||
+ num_context_dense != context_dense_types.size() ||
+ num_context_dense != context_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_context_dense (", num_context_dense,
+ ") must match the size of context_dense_keys (",
+ context_dense_keys.size(), "), context_dense_types (",
+ context_dense_types.size(), ") and context_dense_shapes (",
+ context_dense_shapes.size(), ")");
+ }
+ if (num_feature_list_sparse != feature_list_sparse_keys.size() ||
+ num_feature_list_sparse != feature_list_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_sparse (", num_feature_list_sparse,
+ ") must match the size of feature_list_sparse_keys (",
+ feature_list_sparse_keys.size(), ") and feature_list_sparse_types (",
+ feature_list_sparse_types.size(), ")");
+ }
+ if (num_feature_list_dense != feature_list_dense_keys.size() ||
+ num_feature_list_dense != feature_list_dense_types.size() ||
+ num_feature_list_dense != feature_list_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_dense (", num_feature_list_dense,
+ ") must match the size of feature_list_dense_keys (",
+ feature_list_dense_keys.size(), "), feature_list_dense_types (",
+ feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
+ feature_list_dense_shapes.size(), ")");
+ }
+ for (const DataType& type : context_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : context_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+
+ return Status::OK();
+}
+
Status ParseSingleSequenceExampleAttrs::FinishInit() {
if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h
index e511704962..c183ee4d96 100644
--- a/tensorflow/core/util/example_proto_helper.h
+++ b/tensorflow/core/util/example_proto_helper.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@@ -271,6 +272,66 @@ class ParseSingleExampleAttrs {
Status FinishInit(); // for context-independent parts of Init.
};
+// Parses the attributes passed to ParseSequenceExample.
+// REQUIRES: Init must be called after construction.
+class ParseSequenceExampleAttrs {
+ public:
+ template <typename ContextType>
+ Status Init(ContextType* ctx) {
+ std::vector<string> feature_list_dense_missing_assumed_empty_tmp;
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_missing_assumed_empty",
+ &feature_list_dense_missing_assumed_empty_tmp));
+ for (const string& feature : feature_list_dense_missing_assumed_empty_tmp) {
+ feature_list_dense_missing_assumed_empty.insert(feature);
+ }
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_keys", &context_sparse_keys));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("context_dense_keys", &context_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_keys", &feature_list_sparse_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_types", &context_sparse_types));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
+ return FinishInit();
+ }
+
+ std::unordered_set<string> feature_list_dense_missing_assumed_empty;
+ int64 num_context_sparse;
+ int64 num_context_dense;
+ int64 num_feature_list_sparse;
+ int64 num_feature_list_dense;
+ std::vector<string> context_sparse_keys;
+ std::vector<string> context_dense_keys;
+ std::vector<string> feature_list_sparse_keys;
+ std::vector<string> feature_list_dense_keys;
+ std::vector<DataType> context_sparse_types;
+ std::vector<DataType> context_dense_types;
+ std::vector<TensorShape> context_dense_shapes;
+ std::vector<DataType> feature_list_sparse_types;
+ std::vector<DataType> feature_list_dense_types;
+ std::vector<TensorShape> feature_list_dense_shapes;
+
+ private:
+ Status FinishInit(); // for context-independent parts of Init.
+};
+
// Parses the attributes passed to ParseSingleSequenceExample.
// REQUIRES: Init must be called after construction.
class ParseSingleSequenceExampleAttrs {
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 0a96a603d0..680211edff 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -33,6 +34,12 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
+// Using pragma message since #warning doesn't work with all compilers
+#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
+#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#endif
+
+#ifdef INTEL_MKL_ML_ONLY
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
@@ -50,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -96,6 +104,8 @@ typedef enum {
Dim3d_I = 1
} MklDnnDims3D;
+static const int kSmallBatchSize = 32;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -1994,7 +2004,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -2017,6 +2029,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether HW has AVX512 or AVX2
+ /// For those legacy device(w/o AVX512 and AVX2),
+ /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Fuction to check whether primitive memory optimization is enabled
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
@@ -2054,7 +2082,7 @@ class FactoryKeyCreator {
const char delimiter = 'x';
const int kMaxKeyLength = 256;
void Append(StringPiece s) {
- key_.append(s.ToString());
+ key_.append(string(s));
key_.append(1, delimiter);
}
};
@@ -2093,7 +2121,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2135,7 +2163,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2180,6 +2208,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// utility function to determine if it is conv 1x1 and stride != 1
+// for purpose of temporarily disabling primitive reuse
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow
diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/util/status_util.h
deleted file mode 100644
index ea92f61dce..0000000000
--- a/tensorflow/core/util/status_util.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace tensorflow {
-
-// Creates a tag to be used in an exception error message. This can be parsed by
-// the Python layer and replaced with information about the node.
-//
-// For example, error_format_tag(node, "${file}") returns
-// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as
-// e.g. "file/where/node/was/created.py".
-inline string error_format_tag(const Node& node, const string& format) {
- return strings::StrCat("^^node:", node.name(), ":", format, "^^");
-}
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index 4d1454be0d..c63d4c3c7d 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -634,7 +634,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
label_count = model_settings['label_count']
final_fc_weights = tf.get_variable(
name='final_fc_weights',
- initializer=tf.truncated_normal(stddev=0.01),
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_fc_output_channels, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 0aba0393af..5ebd409b15 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3377,6 +3377,30 @@ func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+// mean_gradients: A tensor with shape=[logits_dimension] with mean of gradients for a first node.
+// mean_hessians: A tensor with shape=[logits_dimension] mean of hessians for a first node.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+//
+// Returns Bool, whether to continue bias centering.
+func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_gradients tf.Output, mean_hessians tf.Output, l1 tf.Output, l2 tf.Output) (continue_centering tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCenterBias",
+ Input: []tf.Input{
+ tree_ensemble_handle, mean_gradients, mean_hessians, l1, l2,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the mean along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -4650,51 +4674,6 @@ func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output)
return op.Output(0)
}
-// Computes the mean along sparse segments of a tensor.
-//
-// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Arguments:
-//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which has size
-// `num_segments`.
-func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentMeanWithNumSegments",
- Input: []tf.Input{
- data, indices, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes hyperbolic cosine of x element-wise.
-func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Cosh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that emits each dim-0 slice of `components` once.
func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -8921,21 +8900,6 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value
return op.Output(0)
}
-// Computes tan of x element-wise.
-func Tan(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tan",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -8976,6 +8940,21 @@ func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, fe
return scope.AddOperation(opspec)
}
+// Computes tan of x element-wise.
+func Tan(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tan",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// EncodeJpegAttr is an optional argument to EncodeJpeg.
type EncodeJpegAttr func(optionalAttr)
@@ -16650,30 +16629,6 @@ func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataTyp
return key, values
}
-// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// mean_gradients: A tensor with shape=[logits_dimension] with mean of gradients for a first node.
-// mean_hessians: A tensor with shape=[logits_dimension] mean of hessians for a first node.
-// l1: l1 regularization factor on leaf weights, per instance based.
-// l2: l2 regularization factor on leaf weights, per instance based.
-//
-// Returns Bool, whether to continue bias centering.
-func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_gradients tf.Output, mean_hessians tf.Output, l1 tf.Output, l2 tf.Output) (continue_centering tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesCenterBias",
- Input: []tf.Input{
- tree_ensemble_handle, mean_gradients, mean_hessians, l1, l2,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SerializeManySparseAttr is an optional argument to SerializeManySparse.
type SerializeManySparseAttr func(optionalAttr)
@@ -20421,6 +20376,51 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// Computes hyperbolic cosine of x element-wise.
+func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Cosh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the mean along sparse segments of a tensor.
+//
+// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
+// misisng, the `output` tensor at that position will be zeroed.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which has size
+// `num_segments`.
+func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMeanWithNumSegments",
+ Input: []tf.Input{
+ data, indices, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize.
type CudnnRNNParamsSizeAttr func(optionalAttr)
@@ -26671,41 +26671,6 @@ func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, o
return op.Output(0)
}
-// Runs multiple additive regression ensemble predictors on input instances and
-//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
-//
-// Arguments:
-//
-// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesTrainingPredict",
- Input: []tf.Input{
- tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MapSizeAttr is an optional argument to MapSize.
type MapSizeAttr func(optionalAttr)
@@ -31918,3 +31883,38 @@ func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Outpu
}
return scope.AddOperation(opspec)
}
+
+// Runs multiple additive regression ensemble predictors on input instances and
+//
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
+//
+// Arguments:
+//
+// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
+//
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesTrainingPredict",
+ Input: []tf.Input{
+ tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 74b001a572..459f494b48 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 28)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 4)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 55d2709845..849d165bfa 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -1100,6 +1100,23 @@ py_test(
],
)
+py_test(
+ name = "disk_usage_test",
+ size = "small",
+ srcs = ["wrappers/disk_usage_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dumping_wrapper",
+ ":hooks",
+ "//tensorflow/python:client",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
sh_test(
name = "examples_test",
size = "medium",
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 7cbaae46b4..019f13c450 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -113,17 +113,16 @@ def main(_):
n_classes=3,
model_dir=model_dir)
- hooks = None
if FLAGS.debug and FLAGS.tensorboard_debug_address:
raise ValueError(
"The --debug and --tensorboard_debug_address flags are mutually "
"exclusive.")
+ hooks = []
if FLAGS.debug:
- debug_hook = tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type,
- dump_root=FLAGS.dump_root)
+ hooks.append(tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type,
+ dump_root=FLAGS.dump_root))
elif FLAGS.tensorboard_debug_address:
- debug_hook = tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address)
- hooks = [debug_hook]
+ hooks.append(tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address))
# Train model, using tfdbg hook.
classifier.train(training_input_fn,
diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py
index f1e972940b..f2a43a6152 100644
--- a/tensorflow/python/debug/lib/debug_utils.py
+++ b/tensorflow/python/debug/lib/debug_utils.py
@@ -87,7 +87,8 @@ def watch_graph(run_options,
op_type_regex_whitelist=None,
tensor_dtype_regex_whitelist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug watches to `RunOptions` for a TensorFlow graph.
To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
@@ -130,6 +131,8 @@ def watch_graph(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -170,6 +173,7 @@ def watch_graph(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
def watch_graph_with_blacklists(run_options,
@@ -180,7 +184,8 @@ def watch_graph_with_blacklists(run_options,
op_type_regex_blacklist=None,
tensor_dtype_regex_blacklist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug tensor watches, blacklisting nodes and op types.
This is similar to `watch_graph()`, but the node names and op types are
@@ -219,6 +224,8 @@ def watch_graph_with_blacklists(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -259,3 +266,4 @@ def watch_graph_with_blacklists(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
diff --git a/tensorflow/python/debug/wrappers/disk_usage_test.py b/tensorflow/python/debug/wrappers/disk_usage_test.py
new file mode 100644
index 0000000000..0874525966
--- /dev/null
+++ b/tensorflow/python/debug/wrappers/disk_usage_test.py
@@ -0,0 +1,109 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.python.client import session
+from tensorflow.python.debug.wrappers import dumping_wrapper
+from tensorflow.python.debug.wrappers import hooks
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import monitored_session
+
+
+class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # For efficient testing, set the disk usage bytes limit to a small
+ # number (10).
+ os.environ["TFDBG_DISK_BYTES_LIMIT"] = "10"
+
+ def setUp(self):
+ self.session_root = tempfile.mkdtemp()
+
+ self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
+ self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
+ self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
+ self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")
+
+ self.sess = session.Session()
+ self.sess.run(self.v.initializer)
+
+ def testWrapperSessionNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r"(.*delta.*|.*inc_v.*)", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+ sess.run(self.inc_v)
+
+ def testWrapperSessionExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+ # Due to the watch function, each run should dump only 1 tensor,
+ # which has a size of 4 bytes, which corresponds to the dumped 'delta:0'
+ # tensor of scalar shape and float32 dtype.
+ # 1st run should pass, after which the disk usage is at 4 bytes.
+ sess.run(self.inc_v)
+ # 2nd run should also pass, after which 8 bytes are used.
+ sess.run(self.inc_v)
+ # 3rd run should fail, because the total byte count (12) exceeds the
+ # limit (10)
+ with self.assertRaises(ValueError):
+ sess.run(self.inc_v)
+
+ def testHookNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ mon_sess.run(self.inc_v)
+
+ def testHookExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ # Like in `testWrapperSessionExceedingLimit`, the first two calls
+ # should be within the byte limit, but the third one should error
+ # out due to exceeding the limit.
+ mon_sess.run(self.inc_v)
+ mon_sess.run(self.inc_v)
+ with self.assertRaises(ValueError):
+ mon_sess.run(self.inc_v)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index b9524ce649..afda1fdc0d 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -447,13 +447,16 @@ class BaseDebugWrapperSession(session.SessionInterface):
"callable_runner and callable_options are mutually exclusive, but "
"are both specified in this call to BaseDebugWrapperSession.run().")
- if not (callable_runner or callable_options):
- self.increment_run_call_count()
- elif callable_runner and (fetches or feed_dict):
+ if callable_runner and (fetches or feed_dict):
raise ValueError(
"callable_runner and fetches/feed_dict are mutually exclusive, "
"but are used simultaneously.")
+ elif callable_options and (fetches or feed_dict):
+ raise ValueError(
+ "callable_options and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
+ self.increment_run_call_count()
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
tf_logging.info(
@@ -649,6 +652,18 @@ class BaseDebugWrapperSession(session.SessionInterface):
def increment_run_call_count(self):
self._run_call_count += 1
+ def _is_disk_usage_reset_each_run(self):
+ """Indicates whether disk usage is reset after each Session.run.
+
+ Subclasses that clean up the disk usage after every run should
+ override this protected method.
+
+ Returns:
+ (`bool`) Whether the disk usage amount is reset to zero after
+ each Session.run.
+ """
+ return False
+
def _decorate_run_options_for_debug(
self,
run_options,
@@ -686,7 +701,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
node_name_regex_whitelist=node_name_regex_whitelist,
op_type_regex_whitelist=op_type_regex_whitelist,
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
- tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
+ tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
+ reset_disk_byte_usage=(self._run_call_count == 1 or
+ self._is_disk_usage_reset_each_run()))
def _decorate_run_options_for_profile(self, run_options):
"""Modify a RunOptions object for profiling TensorFlow graph execution.
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 5e4604fda4..872b675506 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -188,6 +188,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
pass
def before_run(self, run_context):
+ reset_disk_byte_usage = False
if not self._session_wrapper:
self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
run_context.session,
@@ -195,6 +196,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
watch_fn=self._watch_fn,
thread_name_filter=self._thread_name_filter,
log_usage=self._log_usage)
+ reset_disk_byte_usage = True
self._session_wrapper.increment_run_call_count()
@@ -212,7 +214,8 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
tolerate_debug_op_creation_failures=(
- watch_options.tolerate_debug_op_creation_failures))
+ watch_options.tolerate_debug_op_creation_failures),
+ reset_disk_byte_usage=reset_disk_byte_usage)
run_args = session_run_hook.SessionRunArgs(
None, feed_dict=None, options=run_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 668ffb57f1..a3ce4d388b 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -124,6 +124,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._ui_type = ui_type
+ def _is_disk_usage_reset_each_run(self):
+ # The dumped tensors are all cleaned up after every Session.run
+ # in a command-line wrapper.
+ return True
+
def _initialize_argparsers(self):
self._argparsers = {}
ap = argparse.ArgumentParser(
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index a081c30781..bdc869c643 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -34,7 +34,11 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":distribute_coordinator_context",
+ ":multi_worker_util",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
],
)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 46cdd64a6e..bd3562f1ff 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -24,9 +24,10 @@ import os
import threading
import time
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -238,19 +239,26 @@ class _WorkerContext(object):
Returns:
a descendant of SessionCreator.
"""
- # TODO(yuefengz): merge session config.
- if self._strategy.should_init:
+ if config:
+ session_config = copy.deepcopy(config)
+ session_config.MergeFrom(self._session_config)
+ else:
+ session_config = self._session_config
+
+ if not self._strategy or self._strategy.should_init:
+ logging.info("Creating chief session creator with config: %r", config)
return monitored_session.ChiefSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=checkpoint_filename_with_path)
else:
+ logging.info("Creating worker session creator with config: %r", config)
return monitored_session.WorkerSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
max_wait_secs=max_wait_secs)
@property
@@ -313,12 +321,17 @@ def _run_single_worker(worker_fn,
rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
+ session_config = copy.deepcopy(session_config)
strategy = copy.deepcopy(strategy)
# If there is an EVALUATOR task, we run single-machine eval on that task.
if task_type == _TaskType.EVALUATOR:
- strategy.configure(session_config)
+ # It is possible to not have a strategy object for EVALUATOR task.
+ if strategy:
+ strategy.configure(session_config)
else:
+ assert strategy
strategy.configure(session_config, cluster_spec, task_type, task_id)
+
context = _WorkerContext(
strategy,
cluster_spec,
@@ -331,6 +344,25 @@ def _run_single_worker(worker_fn,
worker_fn(strategy)
+def _split_cluster_for_evaluator(cluster_spec, task_type):
+ """Split the cluster for evaluator since it needn't talk to other tasks."""
+ # Splitting the cluster is important to prevent the evaluator from talking to
+ # other tasks in the cluster. Since we allow evaluator not to use
+ # distribution strategies and as a result ops in the evalauator task may have
+ # unspecified devices. Those ops may end up on other tasks if we don't split
+ # the cluster.
+ new_cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec).as_dict()
+ if task_type == _TaskType.EVALUATOR:
+ assert _TaskType.EVALUATOR in new_cluster_spec
+ new_cluster_spec = {
+ _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
+ }
+ else:
+ new_cluster_spec.pop(_TaskType.EVALUATOR, None)
+ return multi_worker_util.normalize_cluster_spec(new_cluster_spec)
+
+
def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
@@ -338,16 +370,19 @@ def _run_std_server(cluster_spec=None,
rpc_layer=None,
environment=None):
"""Runs a standard server."""
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
class _FakeServer(object):
"""A fake server that runs a master session."""
def start(self):
- assert cluster_spec
- target = cluster_spec.task_address(task_type, task_id)
- if rpc_layer:
- target = rpc_layer + "://" + target
# A tensorflow server starts when a remote session is created.
+ logging.info(
+ "Creating a remote session to start a TensorFlow server, "
+ "target = %r, session_config=%r", target, session_config)
session.Session(target=target, config=session_config)
def join(self):
@@ -359,6 +394,13 @@ def _run_std_server(cluster_spec=None,
server.start()
return server
else:
+ if session_config:
+ logging.info(
+ "Starting standard TensorFlow server, target = %r, session_config= "
+ "%r", target, session_config)
+ else:
+ logging.info("Starting standard TensorFlow server, target = %r", target)
+ cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
server = server_lib.Server(
cluster_spec,
job_name=task_type,
@@ -376,7 +418,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -432,6 +474,106 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if eval_thread:
eval_thread.join()
+
+def _configure_session_config_for_std_servers(
+ strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
+ # pylint: disable=g-doc-args
+ """Call strategy's `configure` to mutate the session_config.
+
+ The session_config is currently needed as default config for a TensorFlow
+ server. In the future, we should be able to remove this method and only pass
+ the session config to a client session.
+ """
+ if task_type == _TaskType.EVALUATOR:
+ if eval_strategy:
+ eval_strategy.configure(session_config=session_config)
+ else:
+ # The strategy may be shared in standalone client mode.
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(
+ session_config=session_config,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ # Remove the device filters specific to the strategy, so that the
+ # TensorFlow server brought up with one strategy can be used by other
+ # strategies. The device filters can be set in the client side as well.
+ del session_config.device_filters[:]
+
+
+def run_standard_tensorflow_server(session_config=None):
+ """Starts a standard TensorFlow server.
+
+ This method parses configurations from "TF_CONFIG" environment variable and
+ starts a TensorFlow server. The "TF_CONFIG" is typically a json string and
+ must have information of the cluster and the role of the server in the
+ cluster. One example is:
+
+ TF_CONFIG='{
+ "cluster": {
+ "worker": ["host1:2222", "host2:2222", "host3:2222"],
+ "ps": ["host4:2222", "host5:2222"]
+ },
+ "task": {"type": "worker", "index": 1}
+ }'
+
+ This "TF_CONFIG" specifies there are 3 workers and 2 ps tasks in the cluster
+ and the current role is worker 1.
+
+ Valid task types are "chief", "worker", "ps" and "evaluator" and you can have
+ at most one "chief" and at most one "evaluator".
+
+ An optional key-value can be specified is "rpc_layer". The default value is
+ "grpc".
+
+ Args:
+ session_config: an optional `tf.ConfigProto` object. Users can pass in
+ the session config object to configure server-local devices.
+
+ Returns:
+ a `tf.train.Server` object which has already been started.
+
+ Raises:
+ ValueError: if the "TF_CONFIG" environment is not complete.
+ """
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if "cluster" not in tf_config:
+ raise ValueError("\"cluster\" is not found in TF_CONFIG.")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
+ if "task" not in tf_config:
+ raise ValueError("\"task\" is not found in TF_CONFIG.")
+ task_env = tf_config["task"]
+ if "type" not in task_env:
+ raise ValueError(
+ "\"task_type\" is not found in the `task` part of TF_CONFIG.")
+ task_type = task_env["type"]
+ task_id = int(task_env.get("index", 0))
+
+ rpc_layer = tf_config.get("rpc_layer", "grpc")
+
+ session_config = session_config or config_pb2.ConfigProto()
+ # Set the collective group leader for collective ops to initialize collective
+ # ops when server starts.
+ if "chief" in cluster_spec.jobs:
+ session_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ session_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ server = _run_std_server(
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config,
+ rpc_layer=rpc_layer)
+ server.start()
+ return server
+
+
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
@@ -533,8 +675,10 @@ def run_distribute_coordinator(worker_fn,
strategy: a DistributionStrategy object which specifying whether it should
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
- `cluster_spc`, `task_type` and `task_id`.
- eval_fn: optional function for "evaluator" task.
+ `cluster_spec`, `task_type` and `task_id`.
+ eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
+ in but a "evaluator" task found in the `cluster_spec`, the `worker_fn`
+ will be used for this task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
@@ -558,17 +702,17 @@ def run_distribute_coordinator(worker_fn,
task_id = int(task_env.get("index", task_id))
if cluster_spec:
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
# TODO(yuefengz): validate cluster_spec.
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
environment = tf_config.get("environment", None)
+ # Setting the session config is necessary for some strategies such
+ # CollectiveAllReduceStrategy.
+ session_config = session_config or config_pb2.ConfigProto(
+ allow_soft_placement=True)
+
if cluster_spec:
logging.info(
"Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
@@ -581,11 +725,18 @@ def run_distribute_coordinator(worker_fn,
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
- _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+ _run_single_worker(eval_fn, eval_strategy, None, None, None,
session_config, rpc_layer)
+ else:
+ logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
# The client must know the cluster but servers in the cluster don't have to
# know the client.
@@ -598,10 +749,14 @@ def run_distribute_coordinator(worker_fn,
cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
server.join()
@@ -609,14 +764,24 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
-
- # Every one starts a standard server.
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
+
+ # Every one starts a standard server, get session config from `configure`
+ # method.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 5dd57fa134..b07308a1b5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -23,19 +23,18 @@ import copy
import json
import os
import sys
-import time
import threading
+import time
import six
-# pylint: disable=invalid-name
_portpicker_import_error = None
try:
import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error:
+except ImportError as _error: # pylint: disable=invalid-name
_portpicker_import_error = _error
portpicker = None
-# pylint: enable=invalid-name
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
@@ -96,11 +95,10 @@ class MockStrategy(object):
return self._between_graph
def configure(self,
- session_options=None,
+ session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
- del session_options, cluster_spec, task_type
if self._should_init is None:
if task_id == 0:
self._should_init = True
@@ -117,6 +115,17 @@ class MockStrategy(object):
else:
self._should_save_summary = False
+ if session_config:
+ if (cluster_spec and task_type and task_id is not None and
+ self._between_graph):
+ session_config.intra_op_parallelism_threads += 1
+ if task_type in ["chief", "worker"]:
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
+ else:
+ session_config.inter_op_parallelism_threads += 1
+ session_config.device_filters.append("/job:somejob")
+
@property
def should_init(self):
return self._should_init
@@ -134,6 +143,10 @@ class MockServer(object):
def __init__(self):
self._joined = False
+ self._started = False
+
+ def start(self):
+ self._started = True
def join(self):
assert not self._joined
@@ -143,6 +156,10 @@ class MockServer(object):
def joined(self):
return self._joined
+ @property
+ def started(self):
+ return self._started
+
class DistributeCoordinatorTestBase(test.TestCase):
@@ -151,6 +168,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
# We have to create a global in-process cluster because once an in-process
# tensorflow server is created, there is no way to terminate it. Please see
# multi_worker_test_base.py for more details.
+ # TODO(yuefengz): use the utitliy from multi_worker_test_base.
cls._workers, cls._ps = test_util.create_local_cluster(
NUM_WORKERS, num_ps=NUM_PS)
cls._cluster_spec = {
@@ -175,6 +193,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
with session.Session(graph=None, config=config, target=target) as sess:
yield sess
+ # TODO(yuefengz): use the utitliy from multi_worker_test_base.
def _create_cluster_spec(self,
has_chief=False,
num_workers=1,
@@ -748,7 +767,7 @@ class DistributeCoordinatorTestInpendentWorkerMode(
def _thread_fn(cluster_spec):
distribute_coordinator.run_distribute_coordinator(
None,
- None,
+ MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
@@ -785,7 +804,7 @@ class DistributeCoordinatorTestInpendentWorkerMode(
distribute_coordinator, "_run_std_server", _run_mock_server):
distribute_coordinator.run_distribute_coordinator(
None,
- None,
+ MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
@@ -793,6 +812,121 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+class StrategyConfigureTest(test.TestCase):
+
+ def setUp(self):
+ self._device_filters = []
+ self._intra_op_parallelism_threads = None
+ self._inter_op_parallelism_threads = None
+ super(StrategyConfigureTest, self).setUp()
+
+ def _dump_device_filters(self, *args, **kwargs):
+ session_config = kwargs.get("session_config", None)
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def _worker_fn(self, strategy):
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+ session_config = worker_context._session_config
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def test_session_config_in_std_server(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server",
+ self._dump_device_filters):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._intra_op_parallelism_threads, 1)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_session_config_in_session_creator(self):
+ cluster_spec = {"worker": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ self._worker_fn,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
+ self.assertEqual(self._intra_op_parallelism_threads, 2)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_eval_strategy_configure(self):
+ cluster_spec = {"evaluator": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=False),
+ eval_fn=self._worker_fn,
+ eval_strategy=MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="evaluator",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:somejob"])
+ self.assertEqual(self._intra_op_parallelism_threads, 0)
+ self.assertEqual(self._inter_op_parallelism_threads, 2)
+
+
+class RunStandardTensorflowServerTest(test.TestCase):
+
+ def test_std_server_arguments(self):
+ cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}
+
+ def _mock_run_std_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None):
+ self.assertEqual(cluster_spec.as_dict(), cs)
+ self.assertEqual(task_type, "ps")
+ self.assertEqual(task_id, 0)
+ self.assertEqual(session_config.experimental.collective_group_leader,
+ "/job:worker/replica:0/task:0")
+ self.assertEqual(session_config.intra_op_parallelism_threads, 1)
+ self.assertEqual(rpc_layer, "grpc")
+
+ return MockServer()
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server", _mock_run_std_server):
+ session_config = config_pb2.ConfigProto()
+ session_config.intra_op_parallelism_threads = 1
+ mock_server = distribute_coordinator.run_standard_tensorflow_server(
+ session_config)
+ self.assertTrue(mock_server.started)
+
+
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminite std server threads.
with test.mock.patch.object(sys, "exit", os._exit):
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 202e19c420..e17a598123 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -147,8 +147,8 @@ def init_run_config(config, tf_config):
# `experimental_distribute.remote_cluster` is set.
if (config._train_distribute and config._experimental_distribute and
config._experimental_distribute.remote_cluster):
- if tf_config:
- raise ValueError('Cannot set both TF_CONFIG environment variable and '
+ if cluster_spec:
+ raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
'`experimental_distribute.remote_cluster`')
config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
config._cluster_spec = config._experimental_distribute.remote_cluster
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 7978383e55..9891068056 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -522,7 +522,7 @@ def make_vjp(f, params=None, persistent=True):
args = _ensure_unique_tensor_objects(parameter_positions, args)
for i in parameter_positions:
sources.append(args[i])
- tape.watch(args[i])
+ tape.watch(this_tape, args[i])
result = f(*args)
if result is None:
raise ValueError("Cannot differentiate a function that returns None; "
@@ -748,7 +748,7 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- tape.watch(_handle_or_self(t))
+ tape.watch(self._tape, _handle_or_self(t))
@tf_contextlib.contextmanager
def stop_recording(self):
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3d3f54b9c4..caf36b6a36 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -87,7 +86,6 @@ class BackpropTest(test.TestCase):
initial_value=constant_op.constant(1.0), name='x')
def fn():
- tape.watch_variable(x)
b = constant_op.constant(2.0)
c = math_ops.add(x.value(), b)
return math_ops.add(c, constant_op.constant(3.0))
@@ -194,7 +192,6 @@ class BackpropTest(test.TestCase):
initial_value=random_init, dtype=dtypes.float32, name='embedding')
def f():
- tape.watch_variable(embedding)
embedded_x = embedding_ops.embedding_lookup(embedding, x)
return constant_op.constant(1.0, dtypes.float32) - embedded_x
@@ -316,6 +313,24 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(second, [0])(f)[0]
self.assertAllEqual([[0.0]], grad)
+ @test_util.run_in_graph_and_eager_modes
+ def testWatchingIsTapeLocal(self):
+ x1 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+ x2 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+
+ with backprop.GradientTape() as tape1:
+ with backprop.GradientTape() as tape2:
+ tape1.watch(x1)
+ tape2.watch([x1, x2])
+ y = x1 ** 3
+ z = x2 ** 2
+ dy, dz = tape2.gradient([y, z], [x1, x2])
+ d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
+
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertEqual(self.evaluate(d2y), 12.0)
+ self.assertIsNone(d2z)
+
@test_util.assert_no_new_tensors
def testMakeVJP(self):
@@ -404,7 +419,6 @@ class BackpropTest(test.TestCase):
def f():
with context.device('gpu:0'):
- tape.watch_variable(v)
return v.read_value()
self.assertEqual(
@@ -784,7 +798,6 @@ class BackpropTest(test.TestCase):
initial_value=array_ops.constant([1.0]), name='x')
def fn():
- tape.watch_variable(x)
a = math_ops.add(x.value(), 1.0)
# Make sure convert_to_tensor works correctly with list of TensorNodes.
b = array_ops.stack([a, a], axis=0)
@@ -931,12 +944,12 @@ class BackpropTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
- with backprop.GradientTape() as gt:
+ with backprop.GradientTape() as tape:
tape.watch(x)
x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
y1 = x1**2
y = array_ops.concat([y1, t], axis=1)
- return self.evaluate(gt.gradient(y, x))
+ return self.evaluate(tape.gradient(y, x))
grad1 = get_grad()
grad2 = get_grad()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 13fb0e88a6..778ff85342 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -37,7 +37,7 @@ GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
-_default_mode = GRAPH_MODE
+default_execution_mode = GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
@@ -56,14 +56,18 @@ SYNC = 0
ASYNC = 1
-class _TensorCache(object):
+class _EagerTensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
- def __init__(self, max_items=256):
+ def __init__(self, max_items=256, max_tensor_size=10000):
self._data = collections.OrderedDict()
- self._max_items = max_items if max_items else 256
+ self._max_items = max_items
+ self._max_tensor_size = max_tensor_size
def put(self, key, value):
+ if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
+ return
+
self._data[key] = value
if len(self._data) > self._max_items:
@@ -84,14 +88,14 @@ class _EagerContext(threading.local):
super(_EagerContext, self).__init__()
self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
- self.mode = _default_mode
- self.is_eager = _default_mode == EAGER_MODE
+ self.mode = default_execution_mode
+ self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
self.scalar_cache = {}
- self.ones_rank_cache = _TensorCache()
- self.zeros_cache = _TensorCache()
+ self.ones_rank_cache = _EagerTensorCache()
+ self.zeros_cache = _EagerTensorCache()
self.execution_mode = None
@@ -111,8 +115,8 @@ class _ContextSwitchStack(threading.local):
# Initialize the stack with a pointer to enter the eager context; this
# ensures that the fact that eager execution was enabled is propagated
# across threads, since (1) `enable_eager_execution` modifies a
- # process-level flag (`_default_mode`) and (2) `__init__` is called each
- # time a threading.local object is used in a separate thread.
+ # process-level flag (`default_execution_mode`) and (2) `__init__` is
+ # called each time a threading.local object is used in a separate thread.
self.push(is_building_function=False, enter_context_fn=eager_mode)
def push(self, is_building_function, enter_context_fn):
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cbd6f4cb75..fb5442b646 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -689,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase):
2.0)
+class EagerTensorCacheTest(test_util.TensorFlowTestCase):
+
+ def testCacheSkipsTensorsTooLarge(self):
+ cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
+ cache.put('1', array_ops.zeros((2, 2)))
+ self.assertEqual(cache.get('1'), None)
+
+ cache.put('2', array_ops.zeros((2)))
+ self.assertNotEqual(cache.get('2'), None)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index c12bf89f8f..86fbd24d68 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -263,6 +263,14 @@ typedef struct EagerTensor {
TF_Status* status;
PyObject* weakreflist; /* List of weak references */
+
+ // Per-instance attribute dictionary, to support monkey patching
+ // (e.g. EagerTensor.assign when slicing variables). This dictionary is
+ // created by CPython the first time an attribute is assigned, pointed to by
+ // tp_dictoffset. Note that garbage collection is not enabled for
+ // EagerTensors, so assigning objects to EagerTensor attributes which require
+ // garbage collection is likely to cause issues.
+ PyObject* dict;
} EagerTensor;
namespace {
@@ -311,6 +319,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
+ self->dict = nullptr;
self->weakreflist = nullptr;
PyObject* value;
PyObject* context = nullptr;
@@ -410,6 +419,10 @@ void EagerTensor_dealloc(EagerTensor* self) {
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
Py_DECREF(self->tensor_shape);
+ // If an attribute dictionary has been created, release it. Note that this
+ // is only ever created by CPython's attribute setting methods; we don't
+ // create it ourselves.
+ Py_CLEAR(self->dict);
if (self->handle != nullptr) {
TFE_DeleteTensorHandle(self->handle);
self->handle = nullptr;
@@ -474,6 +487,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
#endif
}
+// Getter for `_num_elements`.
+static PyObject* EagerTensor_num_elements(EagerTensor* self) {
+ auto handle = self->handle;
+ int n = TFE_TensorHandleNumDims(handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
+ tensorflow::int64 value = 1;
+ if (PyErr_Occurred()) return nullptr;
+ for (int i = 0; i < n; ++i) {
+ int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
+ return nullptr;
+ }
+ value *= dim;
+ }
+ return PyLong_FromLongLong(value);
+}
+
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
Py_INCREF(self->handle_data);
return self->handle_data;
@@ -592,6 +629,8 @@ static PyMethodDef EagerTensor_methods[] = {
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+ {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
+ PyDoc_STR("_num_elements")},
{nullptr, nullptr},
};
@@ -660,7 +699,7 @@ static PyTypeObject _EagerTensorType = {
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
- 0, /* tp_dictoffset */
+ offsetof(EagerTensor, dict), /* tp_dictoffset */
(initproc)EagerTensor_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
@@ -788,6 +827,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
return nullptr;
}
+ EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
_EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 823c4078b8..16f8c3c917 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -138,7 +138,7 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
PyObject* TFE_Py_TapeSetIsEmpty();
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
-void TFE_Py_TapeSetWatch(PyObject* tensor);
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
// Stops any gradient recording on the current thread.
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 64cf36d079..0a33a04dcb 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1154,7 +1154,7 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
-void TFE_Py_TapeSetWatch(PyObject* tensor) {
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (*ThreadTapeIsStopped()) {
return;
}
@@ -1162,9 +1162,7 @@ void TFE_Py_TapeSetWatch(PyObject* tensor) {
if (PyErr_Occurred()) {
return;
}
- for (TFE_Py_Tape* tape : *GetTapeSet()) {
- tape->tape->Watch(tensor_id);
- }
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
@@ -1784,6 +1782,7 @@ bool OpDoesntRequireOutput(const string& op_name) {
"ReadVariableOp",
"VarHandleOp",
"Shape",
+ "StridedSlice",
});
return ops_that_dont_require_outputs->find(op_name) !=
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index caa217b70c..6eb62afec4 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -44,13 +44,9 @@ def push_tape(tape):
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
-def watch(tensor):
- """Marks this tensor to be watched by all tapes in the stack.
-
- Args:
- tensor: tensor to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)
+def watch(tape, tensor):
+ """Marks this tensor to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
def watch_variable(variable):
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
index e46a3a156d..1df7216ba6 100644
--- a/tensorflow/python/estimator/canned/baseline_test.py
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -42,13 +42,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
@@ -490,7 +490,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -498,7 +498,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -693,13 +693,13 @@ class BaselineClassifierTrainingTest(test.TestCase):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer.Optimizer,
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index ef7c217190..d104c961d3 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -876,7 +875,7 @@ def _bt_model_fn(
train_op.append(update_model)
with ops.control_dependencies([update_model]):
- increment_global = distribute_lib.increment_var(global_step)
+ increment_global = state_ops.assign_add(global_step, 1).op
train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 4945c3ba11..9799cf9e98 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -31,10 +31,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -161,8 +161,8 @@ def _dnn_linear_combined_model_fn(features,
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
- partitioner=dnn_partitioner):
-
+ partitioner=dnn_partitioner) as scope:
+ dnn_absolute_scope = scope.name
dnn_logit_fn = dnn._dnn_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
hidden_units=dnn_hidden_units,
@@ -186,6 +186,7 @@ def _dnn_linear_combined_model_fn(features,
linear_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as scope:
+ linear_absolute_scope = scope.name
logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
feature_columns=linear_feature_columns,
@@ -211,18 +212,18 @@ def _dnn_linear_combined_model_fn(features,
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=dnn_parent_scope)))
+ scope=dnn_absolute_scope)))
if linear_logits is not None:
train_ops.append(
linear_optimizer.minimize(
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=linear_parent_scope)))
+ scope=linear_absolute_scope)))
train_op = control_flow_ops.group(*train_ops)
with ops.control_dependencies([train_op]):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return head.create_estimator_spec(
features=features,
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index de226ed0ef..11f1e93630 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -44,13 +44,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import optimizer as optimizer_lib
@@ -222,7 +222,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
testcase.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -230,7 +230,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
optimizer_mock = test.mock.NonCallableMagicMock(
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index c3934c7a80..65cdd50061 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -48,13 +48,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer as optimizer_lib
@@ -756,7 +756,7 @@ class BaseLinearRegressorTrainingTest(object):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -764,7 +764,7 @@ class BaseLinearRegressorTrainingTest(object):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -979,13 +979,13 @@ class BaseLinearClassifierTrainingTest(object):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer_lib.Optimizer,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 97a02bd1e8..e44a69b374 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -35,17 +35,16 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
-from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
-from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -957,7 +956,12 @@ class Estimator(object):
mode=mode,
config=self.config)
- export_outputs = self._get_export_outputs_for_spec(estimator_spec)
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode=estimator_spec.mode,
+ serving_export_outputs=estimator_spec.export_outputs,
+ predictions=estimator_spec.predictions,
+ loss=estimator_spec.loss,
+ metrics=estimator_spec.eval_metric_ops)
# Build the SignatureDefs from receivers and all outputs
signature_def_map = export_helpers.build_all_signature_defs(
@@ -1014,45 +1018,6 @@ class Estimator(object):
else:
builder.add_meta_graph(**meta_graph_kwargs)
- def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an `EstimatorSpec`, determine what our export outputs should be.
-
- `EstimatorSpecs` contains `export_outputs` that are used for serving, but
- for
- training and eval graphs, we must wrap the tensors of interest in
- appropriate `tf.estimator.export.ExportOutput` objects.
-
- Args:
- estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
-
- Returns:
- a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
- object.
-
- Raises:
- ValueError: if an appropriate `ExportOutput` cannot be found for the
- passed `EstimatorSpec.mode`
- """
- mode = estimator_spec.mode
- if mode == model_fn_lib.ModeKeys.PREDICT:
- outputs = estimator_spec.export_outputs
- else:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- output_class = export_output.TrainOutput
- elif mode == model_fn_lib.ModeKeys.EVAL:
- output_class = export_output.EvalOutput
- else:
- raise ValueError(
- 'Export output type not found for mode: {}'.format(mode))
-
- export_out = output_class(
- loss=estimator_spec.loss,
- predictions=estimator_spec.predictions,
- metrics=estimator_spec.eval_metric_ops)
- outputs = {mode: export_out}
-
- return outputs
-
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
@@ -1332,10 +1297,12 @@ class Estimator(object):
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._train_distribution)
+ # TODO(yuefengz): add a test for unwrapping per_device_hooks.
def get_hooks_from_the_first_device(per_device_hooks):
- hooks_list = self._train_distribution.unwrap(per_device_hooks)
- assert hooks_list
- return hooks_list[0]
+ return [
+ self._distribution.unwrap(per_device_hook)[0]
+ for per_device_hook in per_device_hooks
+ ]
training_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_hooks)
@@ -1641,21 +1608,6 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
return config
-def create_per_tower_ready_op(scaffold):
- """Create a `tf.train.Scaffold.ready_op` inside a tower."""
- if scaffold.ready_op:
- return scaffold.ready_op
-
- def default_ready_op():
- return array_ops.concat([
- variables.report_uninitialized_variables(),
- resources.report_uninitialized_resources()
- ], 0)
-
- return monitored_session.Scaffold.get_or_default(
- 'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
-
-
def create_per_tower_ready_for_local_init_op(scaffold):
"""Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower."""
if scaffold.ready_for_local_init_op:
@@ -1705,11 +1657,9 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
return value[0]
ready_op = distribution.call_for_each_tower(
- create_per_tower_ready_op, grouped_scaffold)
+ lambda scaffold: scaffold.ready_op, grouped_scaffold)
if ready_op is not None:
ready_op = _unwrap_and_concat(ready_op)
- else:
- ready_op = None
ready_for_local_init_op = distribution.call_for_each_tower(
create_per_tower_ready_for_local_init_op, grouped_scaffold)
@@ -1837,19 +1787,21 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
update_ops = []
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
- for name, metric_ops in sorted(six.iteritems(eval_dict)):
- value_ops[name] = metric_ops[0]
- if distribution:
- update_op = distribution.group(metric_ops[1])
+ for name, value in sorted(six.iteritems(eval_dict)):
+ if isinstance(value, metrics.Metric):
+ metric_result = value.result()
+ # We expect only one update op for every metric when there is no
+ # distribution strategy.
+ metric_update = value.updates if distribution else value.updates[0]
else:
- update_op = metric_ops[1]
- update_ops.append(update_op)
+ metric_result = value[0]
+ metric_update = value[1]
- if update_ops:
- update_op = control_flow_ops.group(*update_ops)
- else:
- update_op = None
+ value_ops[name] = metric_result
+ update_ops.append(
+ distribution.group(metric_update) if distribution else metric_update)
+ update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index d316742a83..1ed5e30b0e 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.layers import layers
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@@ -971,19 +972,28 @@ class EstimatorTrainTest(test.TestCase):
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
- metric_name = params.get('metric_name') or 'metric'
- metric_value = params.get('metric_value') or 2.
global_step = training.get_global_step()
loss = constant_op.constant(1.)
+ metric_name_1 = params.get('metric_name') or 'metric'
+ metric_value_1 = params.get('metric_value') or 2.
+ metric_name_2 = params.get('metric_name_2') or 'metric2'
+ metric_value_2 = params.get('metric_value_2') or 2.
+
metric_update_op = loss.op
metric_tensor = control_flow_ops.with_dependencies(
- [metric_update_op], constant_op.constant(metric_value))
+ [metric_update_op], constant_op.constant(metric_value_1))
+
+ mean = metrics_module.Mean()
+ mean.update_state(metric_value_2)
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
predictions={'predictions': constant_op.constant(1.)},
train_op=state_ops.assign_add(global_step, 1),
- eval_metric_ops={metric_name: (metric_tensor, metric_update_op)})
+ eval_metric_ops={
+ metric_name_1: (metric_tensor, metric_update_op),
+ metric_name_2: mean,
+ })
class _StepCounterHook(session_run_hook.SessionRunHook):
@@ -1167,16 +1177,22 @@ class EstimatorEvaluateTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params):
del features, labels, params
+ mean = metrics_module.Mean()
+ mean.update_state(variables.Variable(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(
- variables.Variable(2.) + 1)})
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ })
+
est = estimator.Estimator(model_fn=_model_fn)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ scores = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is newly
# initialized (since there is no checkpoint).
- self.assertEqual(3., metrics['metric'])
+ self.assertEqual(3., scores['mean1'])
+ self.assertEqual(3., scores['mean2'])
def test_no_checkpoint_uses_init_with_warm_starting(self):
def _make_model_fn(x):
@@ -1184,14 +1200,24 @@ class EstimatorEvaluateTest(test.TestCase):
_, _ = features, labels
x_var = variable_scope.get_variable('x', initializer=x)
global_step = training.get_global_step()
+ mean = metrics_module.Mean()
+ mean.update_state(x_var + 1)
return model_fn_lib.EstimatorSpec(
mode,
predictions={'y': constant_op.constant(1.0)},
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(x_var + 1)},
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(x_var + 1)
+ },
train_op=state_ops.assign_add(global_step, 1),
- export_outputs={'test': export_output.ClassificationOutput(
- constant_op.constant([4.2]), constant_op.constant(['label']))})
+ export_outputs={
+ 'test':
+ export_output.ClassificationOutput(
+ constant_op.constant([4.2]),
+ constant_op.constant(['label']))
+ })
+
return _variable_creating_and_export_model_fn
first_est = estimator.Estimator(model_fn=_make_model_fn(42.))
@@ -1210,30 +1236,37 @@ class EstimatorEvaluateTest(test.TestCase):
# or an exported SavedModel.
est = estimator.Estimator(model_fn=_make_model_fn(52.),
warm_start_from=exported_path)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from the SavedModel of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
est = estimator.Estimator(model_fn=_make_model_fn(62.),
warm_start_from=first_est.model_dir)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from a checkpoint of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
def test_scores(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params={
'metric_name': 'metric',
- 'metric_value': 2.})
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('metric', scores)
self.assertAlmostEqual(2., scores['metric'])
+ self.assertIn('metric2', scores)
+ self.assertAlmostEqual(3., scores['metric2'])
def test_tuple_metrics(self):
def _model_fn(features, labels, mode):
@@ -1284,8 +1317,12 @@ class EstimatorEvaluateTest(test.TestCase):
def test_global_step_is_reported(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
- params={'metric_name': 'metric',
- 'metric_value': 2.})
+ params={
+ 'metric_name': 'metric',
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('global_step', scores)
@@ -1328,7 +1365,10 @@ class EstimatorEvaluateTest(test.TestCase):
def test_evaluate_from_checkpoint(self):
params = {
'metric_name': 'metric',
- 'metric_value': 2.}
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ }
est1 = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params=params)
@@ -2027,8 +2067,15 @@ def _model_fn_with_x_y(features, labels, mode):
multiplied = math_ops.multiply(
features['x'], features['y'], name='{}multiplied'.format(prefix))
- metrics = {'mean': metrics_lib.mean(features['x'] - features['y'],
- name='{}mean'.format(prefix))}
+ mean = metrics_module.Mean(name='{}mean'.format(prefix))
+ mean.update_state(features['x'] - features['y'])
+ eval_metrics = {
+ 'mean1':
+ mean,
+ 'mean2':
+ metrics_lib.mean(
+ features['x'] - features['y'], name='{}mean'.format(prefix))
+ }
variables.Variable(1., name='later_var')
variables.Variable(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
@@ -2036,7 +2083,7 @@ def _model_fn_with_x_y(features, labels, mode):
predictions=multiplied,
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
@@ -2395,14 +2442,18 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
del features, labels # Unused
- metrics = {'metrics': (constant_op.constant([0]),
- control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ eval_metrics = {
+ 'metrics1': (constant_op.constant([0]), control_flow_ops.no_op()),
+ 'metrics2': metric_obj,
+ }
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(model_fn=_model_fn)
@@ -2424,8 +2475,10 @@ class EstimatorExportTest(test.TestCase):
meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
sig_outputs = meta_graph.signature_def[
model_fn_lib.ModeKeys.EVAL].outputs
- self.assertEqual(
- sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+ self.assertTrue(sig_outputs['metrics1/update_op'].name.startswith(
+ 'metric_op_wrapper'))
+ self.assertTrue(sig_outputs['metrics2/update_op'].name.startswith(
+ 'metric_op_wrapper'))
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
@@ -3080,9 +3133,13 @@ class EstimatorIntegrationTest(test.TestCase):
loss = losses.mean_squared_error(labels, predictions)
train_op = training.GradientDescentOptimizer(learning_rate=0.5).minimize(
loss, training.get_global_step())
+ mean = metrics_module.Mean()
+ mean.update_state(loss)
eval_metric_ops = {
- 'absolute_error': metrics_lib.mean_absolute_error(
- labels, predictions)
+ 'absolute_error':
+ metrics_lib.mean_absolute_error(labels, predictions),
+ 'mean':
+ mean,
}
return model_fn_lib.EstimatorSpec(
@@ -3102,12 +3159,13 @@ class EstimatorIntegrationTest(test.TestCase):
x={'x': data}, y=data, batch_size=50, num_epochs=None, shuffle=True)
est.train(train_input_fn, steps=200)
- # EVALUTE
+ # EVALUATE
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data}, y=data, batch_size=50, num_epochs=1, shuffle=True)
scores = est.evaluate(eval_input_fn)
self.assertEqual(200, scores['global_step'])
self.assertGreater(0.1, scores['absolute_error'])
+ self.assertAlmostEqual(4.4e-14, scores['mean'], places=2)
# PREDICT
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 20382a58d8..c17fc08f21 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util.tf_export import estimator_export
@@ -259,7 +260,10 @@ class _SupervisedOutput(ExportOutput):
loss: dict of Tensors or single Tensor representing calculated loss.
predictions: dict of Tensors or single Tensor representing model
predictions.
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) (metric_value, update_op) tuples, or a single tuple.
metric_value must be a Tensor, and update_op must be a Tensor or Op.
Raises:
@@ -311,7 +315,11 @@ class _SupervisedOutput(ExportOutput):
Here, we separate out the tuples and create a dict with names to tensors.
Args:
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op.
Returns:
dict of output_names to tensors
@@ -324,7 +332,13 @@ class _SupervisedOutput(ExportOutput):
metrics = {self.METRICS_NAME: metrics}
outputs = {}
- for key, (metric_val, metric_op) in metrics.items():
+ for key, value in metrics.items():
+ if isinstance(value, metrics_module.Metric):
+ metric_val = value.result()
+ assert len(value.updates) == 1 # We expect only one update op.
+ metric_op = value.updates[0]
+ else:
+ metric_val, metric_op = value
key = self._check_output_key(key, self.METRICS_NAME)
key = self._prefix_key(key, self.METRICS_NAME)
@@ -397,7 +411,3 @@ class EvalOutput(_SupervisedOutput):
def _get_signature_def_fn(self):
return signature_def_utils.supervised_eval_signature_def
-
-
-
-
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index d94c764fd7..96ce0e580d 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -240,16 +241,19 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10])),
- "metrics2": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics": metric_obj,
+ "metrics2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"])
self.assertEqual(
outputter.predictions["predictions/output1"], predictions["output1"])
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper:0")
self.assertEqual(
outputter.metrics["metrics2/update_op"], metrics["metrics2"][1])
@@ -259,7 +263,8 @@ class SupervisedOutputTest(test.TestCase):
self.assertEqual(outputter.loss, {"loss": loss["my_loss"]})
self.assertEqual(
outputter.predictions, {"predictions": predictions["output1"]})
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper_1:0")
def test_supervised_outputs_none(self):
outputter = MockSupervisedOutput(
@@ -282,34 +287,56 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {("my", "loss"): constant_op.constant([0])}
predictions = {(u"output1", "2"): constant_op.constant(["foo"])}
- metrics = {("metrics", "twice"): (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ ("metrics", "1"):
+ metric_obj,
+ ("metrics", "2"): (constant_op.constant([0]),
+ constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"]))
self.assertEqual(set(outputter.predictions.keys()),
set(["predictions/output1/2"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/twice/value", "metrics/twice/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics/1/value", "metrics/1/update_op", "metrics/2/value",
+ "metrics/2/update_op"
+ ]))
def test_supervised_outputs_no_prepend(self):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"loss": constant_op.constant([0])}
predictions = {u"predictions": constant_op.constant(["foo"])}
- metrics = {u"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss"]))
self.assertEqual(set(outputter.predictions.keys()), set(["predictions"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/value", "metrics/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics_1/value", "metrics_1/update_op", "metrics_2/update_op",
+ "metrics_2/value"
+ ]))
def test_train_signature_def(self):
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = export_output_lib.TrainOutput(loss, predictions, metrics)
@@ -318,7 +345,8 @@ class SupervisedOutputTest(test.TestCase):
sig_def = outputter.as_signature_def(receiver)
self.assertTrue("loss/my_loss" in sig_def.outputs)
- self.assertTrue("metrics/value" in sig_def.outputs)
+ self.assertTrue("metrics_1/value" in sig_def.outputs)
+ self.assertTrue("metrics_2/value" in sig_def.outputs)
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
@@ -337,18 +365,33 @@ class SupervisedOutputTest(test.TestCase):
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
- def test_metric_op_is_operation(self):
+ def test_metric_op_is_tensor(self):
"""Tests that ops.Operation is wrapped by a tensor for metric_ops."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), control_flow_ops.no_op())
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
- self.assertEqual(
- outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0")
+
+ self.assertTrue(outputter.metrics["metrics_1/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_1/update_op"], ops.Tensor))
self.assertTrue(
- isinstance(outputter.metrics["metrics/update_op"], ops.Tensor))
+ isinstance(outputter.metrics["metrics_1/value"], ops.Tensor))
+
+ self.assertEqual(outputter.metrics["metrics_2/value"],
+ metrics["metrics_2"][0])
+ self.assertTrue(outputter.metrics["metrics_2/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_2/update_op"], ops.Tensor))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6361c6acc1..6b2765be82 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -182,10 +182,58 @@ def _clone_and_build_model(mode,
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
input_tensors, target_tensors = _convert_estimator_io_to_keras(
keras_model, features, labels)
- return models.clone_and_build_model(
+
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
+
+ global_step = None
+ if compile_clone:
+ # Set iterations to the global step created by tf.train.create_global_step()
+ # which is automatically run in the estimator framework.
+ global_step = training_util.get_or_create_global_step()
+ K.track_variable(global_step)
+
+ clone = models.clone_and_build_model(
keras_model, input_tensors, target_tensors, custom_objects,
- compile_clone=(mode != model_fn_lib.ModeKeys.PREDICT),
- in_place_reset=(not keras_model._is_graph_network))
+ compile_clone=compile_clone,
+ in_place_reset=(not keras_model._is_graph_network),
+ optimizer_iterations=global_step)
+
+ return clone
+
+
+def _convert_keras_metrics_to_estimator(model):
+ """Convert metrics from a Keras model to ops used by the Estimator framework.
+
+ Args:
+ model: A `tf.keras.Model` object.
+
+ Returns:
+ Dictionary mapping metric names to tuples of (value, update) ops. May return
+ `None` if the model does not contain any metrics.
+ """
+ if not getattr(model, 'metrics', None):
+ return None
+
+ # TODO(psv/fchollet): support stateful metrics
+ eval_metric_ops = {}
+ # When each metric maps to an output
+ if isinstance(model.metrics, dict):
+ for i, output_name in enumerate(model.metrics.keys()):
+ metric_name = model.metrics[output_name]
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ # When some outputs use the same metric
+ if list(model.metrics.values()).count(metric_name) > 1:
+ metric_name += '_' + output_name
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
+ else:
+ for i, metric_name in enumerate(model.metrics):
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
+ return eval_metric_ops
def _create_keras_model_fn(keras_model, custom_objects=None):
@@ -237,26 +285,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
model._make_test_function() # pylint: disable=protected-access
loss = model.total_loss
- if model.metrics:
- # TODO(psv/fchollet): support stateful metrics
- eval_metric_ops = {}
- # When each metric maps to an output
- if isinstance(model.metrics, dict):
- for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
- # When some outputs use the same metric
- if list(model.metrics.values()).count(metric_name) > 1:
- metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
- else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ eval_metric_ops = _convert_keras_metrics_to_estimator(model)
# Set train_op only during train.
if mode is model_fn_lib.ModeKeys.TRAIN:
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 007970bef7..439cc2e3a4 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.metrics import Metric
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -141,13 +142,15 @@ class EstimatorSpec(
prediction.
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
- train_op: Op to run one training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the
- dict are the results of calling a metric function, namely a
- `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
- without any impact on state (typically is a pure computation results
- based on variables.). For example, it should not trigger the `update_op`
- or requires any input fetching.
+ train_op: Op for the training step.
+ eval_metric_ops: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) Results of calling a metric function, namely a
+ `(metric_tensor, update_op)` tuple. `metric_tensor` should be
+ evaluated without any impact on state (typically is a pure computation
+ results based on variables.). For example, it should not trigger the
+ `update_op` or requires any input fetching.
export_outputs: Describes the output signatures to be exported to
`SavedModel` and used during serving.
A dict `{name: output}` where:
@@ -218,21 +221,27 @@ class EstimatorSpec(
if not isinstance(eval_metric_ops, dict):
raise TypeError(
'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
- for key, metric_value_and_update in six.iteritems(eval_metric_ops):
- if (not isinstance(metric_value_and_update, tuple) or
- len(metric_value_and_update) != 2):
- raise TypeError(
- 'Values of eval_metric_ops must be (metric_value, update_op) '
- 'tuples, given: {} for key: {}'.format(
- metric_value_and_update, key))
- metric_value, metric_update = metric_value_and_update
- for metric_value_member in nest.flatten(metric_value):
- # Allow (possibly nested) tuples for metric values, but require that
- # each of them be Tensors or Operations.
- _check_is_tensor_or_operation(metric_value_member,
+ for key, value in six.iteritems(eval_metric_ops):
+ # TODO(psv): When we deprecate the old metrics, throw an error here if
+ # the value is not an instance of `Metric` class.
+ if isinstance(value, Metric):
+ if not value.updates: # Check if metrics updates are available.
+ raise ValueError(
+ 'Please call update_state(...) on the "{metric_name}" metric'
+ .format(metric_name=value.name))
+ else:
+ if not isinstance(value, tuple) or len(value) != 2:
+ raise TypeError(
+ 'Values of eval_metric_ops must be (metric_value, update_op) '
+ 'tuples, given: {} for key: {}'.format(value, key))
+ metric_value, metric_update = value
+ for metric_value_member in nest.flatten(metric_value):
+ # Allow (possibly nested) tuples for metric values, but require that
+ # each of them be Tensors or Operations.
+ _check_is_tensor_or_operation(metric_value_member,
+ 'eval_metric_ops[{}]'.format(key))
+ _check_is_tensor_or_operation(metric_update,
'eval_metric_ops[{}]'.format(key))
- _check_is_tensor_or_operation(metric_update,
- 'eval_metric_ops[{}]'.format(key))
# Validate the passed export outputs, or generate defaults.
if mode == ModeKeys.PREDICT:
@@ -267,8 +276,12 @@ class EstimatorSpec(
if train_op is not None and train_op.graph is not default_graph:
raise ValueError(error_message_template.format('train_op', train_op.name))
for key, value in list(six.iteritems(eval_metric_ops)):
- values = nest.flatten(value)
- for val in values:
+ if isinstance(value, Metric):
+ values_to_check = value.updates[:]
+ values_to_check.append(value.result())
+ else:
+ values_to_check = nest.flatten(value)
+ for val in values_to_check:
if val.graph is not default_graph:
raise ValueError(error_message_template.format(
'eval_metric_ops',
@@ -287,6 +300,19 @@ class EstimatorSpec(
'All hooks must be SessionRunHook instances, given: {}'.format(
hook))
+ # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
+ # are by default not added to any collections. We are doing this here, so
+ # that metric variables get initialized.
+ local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ vars_to_add = set()
+ for key, value in six.iteritems(eval_metric_ops):
+ if isinstance(value, Metric):
+ vars_to_add.update(value.variables)
+ # Remove variables that are in the local variables collection already.
+ vars_to_add = vars_to_add.difference(local_vars)
+ for v in vars_to_add:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)
+
scaffold = scaffold or monitored_session.Scaffold()
# Validate scaffold.
if not isinstance(scaffold, monitored_session.Scaffold):
@@ -449,3 +475,44 @@ def _check_is_tensor(x, tensor_name):
if not isinstance(x, ops.Tensor):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
+
+
+def export_outputs_for_mode(
+ mode, serving_export_outputs=None, predictions=None, loss=None,
+ metrics=None):
+ """Util function for constructing a `ExportOutput` dict given a mode.
+
+ The returned dict can be directly passed to `build_all_signature_defs` helper
+ function as the `export_outputs` argument, used for generating a SignatureDef
+ map.
+
+ Args:
+ mode: A `ModeKeys` specifying the mode.
+ serving_export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict or None.
+ predictions: A dict of Tensors or single Tensor representing model
+ predictions. This argument is only used if serving_export_outputs is not
+ set.
+ loss: A dict of Tensors or single Tensor representing calculated loss.
+ metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op
+
+ Returns:
+ Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object
+ The key is the expected SignatureDef key for the mode.
+
+ Raises:
+ ValueError: if an appropriate ExportOutput cannot be found for the mode.
+ """
+ # TODO(b/113185250): move all model export helper functions into an util file.
+ if mode == ModeKeys.PREDICT:
+ return _get_export_outputs(serving_export_outputs, predictions)
+ elif mode == ModeKeys.TRAIN:
+ return {mode: export_output_lib.TrainOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ elif mode == ModeKeys.EVAL:
+ return {mode: export_output_lib.EvalOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ else:
+ raise ValueError(
+ 'Export output type not found for mode: {}'.format(mode))
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index b6f1b16a22..8a3a9f3f51 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -60,12 +61,17 @@ class EstimatorSpecTrainTest(test.TestCase):
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -212,12 +218,17 @@ class EstimatorSpecEvalTest(test.TestCase):
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -423,7 +434,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ((('NonTensor',),),
control_flow_ops.no_op())})
- def testEvalMetricOpsFromDifferentGraph(self):
+ def testEvalMetricOpsFromDifferentGraphWithMetricTuple(self):
with ops.Graph().as_default():
eval_metric_ops = {
'loss': (control_flow_ops.no_op(), constant_op.constant(1.))}
@@ -437,6 +448,33 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss,
eval_metric_ops=eval_metric_ops)
+ def testEvalMetricOpsFromDifferentGraphWithMetricObject(self):
+ with ops.Graph().as_default():
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(constant_op.constant(1.))
+ eval_metric_ops = {'metric': metric_obj}
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(
+ ValueError, 'must be from the default graph'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
+ def testEvalMetricOpsWithoutUpdates(self):
+ with ops.Graph().as_default():
+ eval_metric_ops = {'mean': metrics.Mean()}
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(ValueError, 'Please call update_state(...)'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
class EstimatorSpecInferTest(test.TestCase):
"""Tests EstimatorSpec in infer mode."""
@@ -454,12 +492,17 @@ class EstimatorSpecInferTest(test.TestCase):
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index 6e844e14b9..a69018d00d 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -26,21 +26,17 @@ import collections
import itertools
import os
import re
-import string
import six
from tensorflow.python.util import tf_stack
-
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
-_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
-_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
- name=_NAME_REGEX, fmt=_FORMAT_REGEX)
+_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
-_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
+_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
_BAD_FILE_SUBSTRINGS = [
os.path.join("tensorflow", "python"),
@@ -52,16 +48,9 @@ def _parse_message(message):
"""Parses the message.
Splits the message into separators and tags. Tags are named tuples
- representing the string ^^type:name:format^^ and they are separated by
- separators. For example, in
- "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
- three separators. The separators are the numeric characters.
-
- Supported tags after node:<node_name>
- file: Replaced with the filename in which the node was defined.
- line: Replaced by the line number at which the node was defined.
- colocations: Replaced by a multi-line message describing the file and
- line numbers at which this node was colocated with other nodes.
+ representing the string ^^type:name^^ and they are separated by
+ separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are
+ two tags and three separators. The separators are the numeric characters.
Args:
message: String to parse
@@ -69,8 +58,8 @@ def _parse_message(message):
Returns:
(list of separator strings, list of _ParseTags).
- For example, if message is "123^^node:Foo:${file}^^456" then this function
- returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")])
+ For example, if message is "123^^node:Foo^^456" then this function
+ returns (["123", "456"], [_ParseTag("node", "Foo")])
"""
seps = []
tags = []
@@ -79,7 +68,7 @@ def _parse_message(message):
match = re.match(_INTERPOLATION_PATTERN, message[pos:])
if match:
seps.append(match.group(1))
- tags.append(_ParseTag(match.group(3), match.group(4), match.group(5)))
+ tags.append(_ParseTag(match.group(3), match.group(4)))
pos += match.end()
else:
break
@@ -111,12 +100,12 @@ def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
return prefix + message
str_list = []
- str_list.append("%sDevice assignments active during op '%s' creation:"
- % (prefix, name))
+ str_list.append(
+ "%sDevice assignments active during op '%s' creation:" % (prefix, name))
for traceable_obj in device_assignment_list:
- location_summary = "<{file}:{line}>".format(file=traceable_obj.filename,
- line=traceable_obj.lineno)
+ location_summary = "<{file}:{line}>".format(
+ file=traceable_obj.filename, line=traceable_obj.lineno)
subs = {
"prefix": prefix,
"indent": " ",
@@ -160,12 +149,12 @@ def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
return prefix + message
str_list = []
- str_list.append("%sNode-device colocations active during op '%s' creation:"
- % (prefix, name))
+ str_list.append("%sNode-device colocations active during op '%s' creation:" %
+ (prefix, name))
for coloc_name, location in colocation_dict.items():
- location_summary = "<{file}:{line}>".format(file=location.filename,
- line=location.lineno)
+ location_summary = "<{file}:{line}>".format(
+ file=location.filename, line=location.lineno)
subs = {
"prefix": prefix,
"indent": " ",
@@ -180,8 +169,10 @@ def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
def _compute_colocation_summary_from_op(op, prefix=""):
"""Fetch colocation file, line, and nesting and return a summary string."""
- return _compute_colocation_summary_from_dict(
- op.name, op._colocation_dict, prefix) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
+ prefix)
+ # pylint: enable=protected-access
def _find_index_of_defining_frame_for_op(op):
@@ -276,7 +267,7 @@ def compute_field_dict(op):
def interpolate(error_message, graph):
"""Interpolates an error message.
- The error message can contain tags of the form ^^type:name:format^^ which will
+ The error message can contain tags of the form ^^type:name^^ which will
be replaced.
Args:
@@ -285,29 +276,29 @@ def interpolate(error_message, graph):
message.
Returns:
- The string with tags of the form ^^type:name:format^^ interpolated.
+ The string with tags of the form ^^type:name^^ interpolated.
"""
seps, tags = _parse_message(error_message)
+ subs = []
+ end_msg = ""
- node_name_to_substitution_dict = {}
- for name in [t.name for t in tags]:
- if name in node_name_to_substitution_dict:
- continue
+ for t in tags:
try:
- op = graph.get_operation_by_name(name)
+ op = graph.get_operation_by_name(t.name)
except KeyError:
op = None
+ msg = "^^%s:%s^^" % (t.type, t.name)
if op is not None:
field_dict = compute_field_dict(op)
- else:
- msg = "<NA>"
- field_dict = collections.defaultdict(lambda s=msg: s)
- node_name_to_substitution_dict[name] = field_dict
-
- subs = [
- string.Template(tag.format).safe_substitute(
- node_name_to_substitution_dict[tag.name]) for tag in tags
- ]
+ if t.type == "node":
+ msg = "node %s%s " % (t.name, field_dict["defined_at"])
+ elif t.type == "colocation_node":
+ msg = "node %s%s having device %s " % (t.name, field_dict["defined_at"],
+ field_dict["devices"])
+ end_msg += "\n\n" + field_dict["devs_and_colocs"]
+ subs.append(msg)
+ subs.append(end_msg)
+
return "".join(
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index 0427156b2b..a7c7bbf28b 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -50,9 +50,9 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
stack = []
for idx in range(0, num_outer_frames):
stack.append(op._traceback[idx])
- for idx in range(len(stack), len(stack)+num_user_frames):
+ for idx in range(len(stack), len(stack) + num_user_frames):
stack.append(_make_frame_with_filename(op, idx, user_filename % idx))
- for idx in range(len(stack), len(stack)+num_inner_tf_frames):
+ for idx in range(len(stack), len(stack) + num_inner_tf_frames):
stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
op._traceback = stack
@@ -62,13 +62,11 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
def testCorrectFormatWithActiveDeviceAssignments(self):
assignments = []
assignments.append(
- traceable_stack.TraceableObject("/cpu:0",
- filename="hope.py",
- lineno=24))
+ traceable_stack.TraceableObject(
+ "/cpu:0", filename="hope.py", lineno=24))
assignments.append(
- traceable_stack.TraceableObject("/gpu:2",
- filename="please.py",
- lineno=42))
+ traceable_stack.TraceableObject(
+ "/gpu:2", filename="please.py", lineno=42))
summary = error_interpolation._compute_device_summary_from_list(
"nodename", assignments, prefix=" ")
@@ -90,12 +88,10 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
class ComputeColocationSummaryFromOpTest(test.TestCase):
def testCorrectFormatWithActiveColocations(self):
- t_obj_1 = traceable_stack.TraceableObject(None,
- filename="test_1.py",
- lineno=27)
- t_obj_2 = traceable_stack.TraceableObject(None,
- filename="test_2.py",
- lineno=38)
+ t_obj_1 = traceable_stack.TraceableObject(
+ None, filename="test_1.py", lineno=27)
+ t_obj_2 = traceable_stack.TraceableObject(
+ None, filename="test_2.py", lineno=38)
colocation_dict = {
"test_node_1": t_obj_1,
"test_node_2": t_obj_2,
@@ -140,10 +136,11 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def testFindIndexOfDefiningFrameForOp(self):
local_op = constant_op.constant(42).op
user_filename = "hope.py"
- _modify_op_stack_with_filenames(local_op,
- num_user_frames=3,
- user_filename=user_filename,
- num_inner_tf_frames=5)
+ _modify_op_stack_with_filenames(
+ local_op,
+ num_user_frames=3,
+ user_filename=user_filename,
+ num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
# Expected frame is 6th from the end because there are 5 inner frames witih
# TF filenames.
@@ -155,44 +152,39 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
# Truncate stack to known length.
local_op._traceback = local_op._traceback[:7]
# Ensure all frames look like TF frames.
- _modify_op_stack_with_filenames(local_op,
- num_user_frames=0,
- user_filename="user_file.py",
- num_inner_tf_frames=7)
+ _modify_op_stack_with_filenames(
+ local_op,
+ num_user_frames=0,
+ user_filename="user_file.py",
+ num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
self.assertEqual(0, idx)
def testNothingToDo(self):
normal_string = "This is just a normal string"
- interpolated_string = error_interpolation.interpolate(normal_string,
- self.graph)
+ interpolated_string = error_interpolation.interpolate(
+ normal_string, self.graph)
self.assertEqual(interpolated_string, normal_string)
- def testOneTag(self):
- one_tag_string = "^^node:Two:${file}^^"
- interpolated_string = error_interpolation.interpolate(one_tag_string,
- self.graph)
- self.assertTrue(interpolated_string.endswith("constant_op.py"),
- "interpolated_string '%s' did not end with constant_op.py"
- % interpolated_string)
-
def testOneTagWithAFakeNameResultsInPlaceholders(self):
- one_tag_string = "^^node:MinusOne:${file}^^"
- interpolated_string = error_interpolation.interpolate(one_tag_string,
- self.graph)
- self.assertEqual("<NA>", interpolated_string)
+ one_tag_string = "^^node:MinusOne^^"
+ interpolated_string = error_interpolation.interpolate(
+ one_tag_string, self.graph)
+ self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self):
- two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^"
- interpolated_string = error_interpolation.interpolate(two_tags_no_seps,
- self.graph)
- self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+")
+ two_tags_no_seps = "^^node:One^^^^node:Three^^"
+ interpolated_string = error_interpolation.interpolate(
+ two_tags_no_seps, self.graph)
+ self.assertRegexpMatches(interpolated_string,
+ "constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
def testTwoTagsWithSeps(self):
- two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;"
- interpolated_string = error_interpolation.interpolate(two_tags_with_seps,
- self.graph)
- expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$"
+ two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;"
+ interpolated_string = error_interpolation.interpolate(
+ two_tags_with_seps, self.graph)
+ expected_regex = (
+ r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
@@ -214,30 +206,26 @@ class InterpolateDeviceSummaryTest(test.TestCase):
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self):
- message = "^^node:zero:${devices}^^"
+ message = "^^colocation_node:zero^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self):
- message = "^^node:one:${devices}^^"
+ message = "^^colocation_node:one^^"
result = error_interpolation.interpolate(message, self.graph)
- num_devices = result.count("tf.device")
- self.assertEqual(1, num_devices)
- self.assertIn("tf.device(/cpu)", result)
+ self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self):
- message = "^^node:two:${devices}^^"
+ message = "^^colocation_node:two^^"
result = error_interpolation.interpolate(message, self.graph)
- num_devices = result.count("tf.device")
- self.assertEqual(2, num_devices)
- self.assertIn("tf.device(/cpu)", result)
- self.assertIn("tf.device(/cpu:0)", result)
+ self.assertEqual(2, result.count("tf.device(/cpu)"))
+ self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
- message = "^^node:three:${devices}^^"
+ message = "^^colocation_node:three^^"
result = error_interpolation.interpolate(message, self.graph)
num_devices = result.count("tf.device")
- self.assertEqual(1, num_devices)
+ self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
expected_re = r"with tf.device\(.*%s\)" % name_re
self.assertRegexpMatches(result, expected_re)
@@ -268,27 +256,26 @@ class InterpolateColocationSummaryTest(test.TestCase):
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self):
- message = "^^node:Three_with_one:${colocations}^^"
+ message = "^^colocation_node:Three_with_one^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
- message = "^^node:Four_with_three:${colocations}^^"
+ message = "^^colocation_node:Four_with_three^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
- "Node One should not appear in Four_with_three's summary:\n%s"
- % result)
+ "Node One should not appear in Four_with_three's summary:\n%s" % result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
- message = "^^node:Five_with_one_with_two:${colocations}^^"
+ message = "^^colocation_node:Five_with_one_with_two^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
- message = "^^node:One:${colocations}^^"
+ message = "^^colocation_node:One^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 9f973de400..5af71f2cfb 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -25,6 +25,7 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -47,11 +48,17 @@ class OpError(Exception):
error_code: The `error_codes_pb2.Code` describing the error.
"""
super(OpError, self).__init__()
- self._message = message
self._node_def = node_def
self._op = op
+ self._message = message
self._error_code = error_code
+ def __reduce__(self):
+ # Allow the subclasses to accept less arguments in their __init__.
+ init_argspec = tf_inspect.getargspec(self.__class__.__init__)
+ args = tuple(getattr(self, arg) for arg in init_argspec.args[1:])
+ return self.__class__, args
+
@property
def message(self):
"""The error message that describes the error."""
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
index 62f8ab030c..574b126cae 100644
--- a/tensorflow/python/framework/errors_test.py
+++ b/tensorflow/python/framework/errors_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import gc
+import pickle
import warnings
from tensorflow.core.lib.core import error_codes_pb2
@@ -107,6 +108,34 @@ class ErrorsTest(test.TestCase):
gc.collect()
self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
+ def testPickleable(self):
+ for error_code in [
+ errors.CANCELLED,
+ errors.UNKNOWN,
+ errors.INVALID_ARGUMENT,
+ errors.DEADLINE_EXCEEDED,
+ errors.NOT_FOUND,
+ errors.ALREADY_EXISTS,
+ errors.PERMISSION_DENIED,
+ errors.UNAUTHENTICATED,
+ errors.RESOURCE_EXHAUSTED,
+ errors.FAILED_PRECONDITION,
+ errors.ABORTED,
+ errors.OUT_OF_RANGE,
+ errors.UNIMPLEMENTED,
+ errors.INTERNAL,
+ errors.UNAVAILABLE,
+ errors.DATA_LOSS,
+ ]:
+ # pylint: disable=protected-access
+ exc = errors_impl._make_specific_exception(None, None, None, error_code)
+ # pylint: enable=protected-access
+ unpickled = pickle.loads(pickle.dumps(exc))
+ self.assertEqual(exc.node_def, unpickled.node_def)
+ self.assertEqual(exc.op, unpickled.op)
+ self.assertEqual(exc.message, unpickled.message)
+ self.assertEqual(exc.error_code, unpickled.error_code)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8d72eb39c0..4cfd639bf9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -802,6 +802,19 @@ class _EagerTensorBase(Tensor):
"""
raise NotImplementedError()
+ def _num_elements(self):
+ """Number of elements of this Tensor.
+
+ Unlike regular Tensors, the number of elements is always known for
+ EagerTensors.
+
+ This is more performant than tensor.shape.num_elements
+
+ Returns:
+ Long - num elements in the tensor
+ """
+ raise NotImplementedError()
+
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
raise NotImplementedError()
@@ -5399,7 +5412,7 @@ def enable_eager_execution(config=None,
TensorFlow graph, or if options provided conflict with a previous call
to this function.
"""
- if context._default_mode != context.EAGER_MODE: # pylint: disable=protected-access
+ if context.default_execution_mode != context.EAGER_MODE:
return enable_eager_execution_internal(
config=config,
device_policy=device_policy,
@@ -5442,15 +5455,15 @@ def enable_eager_execution_internal(config=None,
raise ValueError(
"execution_mode must be one of None, tf.contrib.eager.SYNC, "
"tf.contrib.eager.ASYNC")
- # pylint: disable=protected-access
- if context._default_mode == context.GRAPH_MODE:
+ if context.default_execution_mode == context.GRAPH_MODE:
graph_mode_has_been_used = (
_default_session_stack.stack
or len(get_default_graph().get_operations()) > 0) # pylint: disable=g-explicit-length-test
if graph_mode_has_been_used:
raise ValueError(
"tf.enable_eager_execution must be called at program startup.")
- context._default_mode = context.EAGER_MODE
+ context.default_execution_mode = context.EAGER_MODE
+ # pylint: disable=protected-access
if context._context is None:
context._context = context.Context(
config=config,
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 7cddd861c8..b5388ad0b2 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -866,6 +866,18 @@ def device(use_gpu):
yield
+class ErrorLoggingSession(session.Session):
+ """Wrapper around a Session that logs errors in run().
+ """
+
+ def run(self, *args, **kwargs):
+ try:
+ return super(ErrorLoggingSession, self).run(*args, **kwargs)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(str(e))
+ raise
+
+
@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
@@ -1853,7 +1865,7 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
return config
- return session.Session(graph=graph, config=prepare_config(config))
+ return ErrorLoggingSession(graph=graph, config=prepare_config(config))
@contextlib.contextmanager
def _get_cached_session(self,
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 5523d70a8d..7246341519 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -388,7 +388,7 @@ py_test(
py_test(
name = "embeddings_test",
- size = "small",
+ size = "medium",
srcs = ["layers/embeddings_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 5feedc43a5..85f1d6299f 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -160,7 +160,7 @@ def fit_loop(
outs = [outs]
outs = _aggregate_metrics_across_towers(
- len(current_strategy._devices), out_labels, outs)
+ current_strategy.num_towers, out_labels, outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(step_index, batch_logs)
@@ -263,7 +263,7 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
for step in range(steps):
batch_outs = distributed_test_function(ins)
batch_outs = _aggregate_metrics_across_towers(
- len(current_strategy._devices), model.metrics_names, batch_outs)
+ current_strategy.num_towers, model.metrics_names, batch_outs)
if isinstance(batch_outs, list):
if step == 0:
for _ in enumerate(batch_outs):
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index b9d856efa8..cac78c44ca 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -20,14 +20,15 @@ from __future__ import print_function
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
-from tensorflow.python.ops.init_ops import glorot_normal_initializer
-from tensorflow.python.ops.init_ops import glorot_uniform_initializer
+from tensorflow.python.ops.init_ops import GlorotNormal
+from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
@@ -36,15 +37,84 @@ from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unuse
from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
-from tensorflow.python.ops.init_ops import RandomNormal
-from tensorflow.python.ops.init_ops import RandomUniform
-from tensorflow.python.ops.init_ops import TruncatedNormal
+from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
+from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
+from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
from tensorflow.python.util.tf_export import tf_export
+@tf_export('keras.initializers.TruncatedNormal',
+ 'keras.initializers.truncated_normal')
+class TruncatedNormal(TFTruncatedNormal):
+ """Initializer that generates a truncated normal distribution.
+
+ These values are similar to values from a `random_normal_initializer`
+ except that values more than two standard deviations from the mean
+ are discarded and re-drawn. This is the recommended initializer for
+ neural network weights and filters.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values to
+ generate. Defaults to 0.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the random
+ values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type. Only floating point types are supported.
+ """
+
+ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
+ super(TruncatedNormal, self).__init__(
+ mean=mean, stddev=stddev, seed=seed, dtype=dtype)
+
+
+@tf_export('keras.initializers.RandomUniform', 'keras.initializers.uniform',
+ 'keras.initializers.random_uniform')
+class RandomUniform(TFRandomUniform):
+ """Initializer that generates tensors with a uniform distribution.
+
+ Args:
+ minval: A python scalar or a scalar tensor. Lower bound of the range of
+ random values to generate. Defaults to -0.05.
+ maxval: A python scalar or a scalar tensor. Upper bound of the range of
+ random values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type.
+ """
+
+ def __init__(self, minval=-0.05, maxval=0.05, seed=None,
+ dtype=dtypes.float32):
+ super(RandomUniform, self).__init__(
+ minval=minval, maxval=maxval, seed=seed, dtype=dtype)
+
+
+@tf_export('keras.initializers.RandomNormal', 'keras.initializers.normal',
+ 'keras.initializers.random_normal')
+class RandomNormal(TFRandomNormal):
+ """Initializer that generates tensors with a normal distribution.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values to
+ generate. Defaults to 0.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the random
+ values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type. Only floating point types are supported.
+
+ Returns:
+ RandomNormal instance.
+ """
+
+ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
+ super(RandomNormal, self).__init__(
+ mean=mean, stddev=stddev, seed=seed, dtype=dtype)
+
+
# Compatibility aliases
# pylint: disable=invalid-name
@@ -56,10 +126,9 @@ normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
-glorot_normal = glorot_normal_initializer
-glorot_uniform = glorot_uniform_initializer
+glorot_normal = GlorotNormal
+glorot_uniform = GlorotUniform
-# pylint: enable=invalid-name
# Utility functions
@@ -92,3 +161,6 @@ def get(identifier):
else:
raise ValueError('Could not interpret initializer identifier: ' +
str(identifier))
+
+
+# pylint: enable=invalid-name
diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py
index 8ddc9a17bf..2b758a98f3 100644
--- a/tensorflow/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/initializers_test.py
@@ -146,6 +146,21 @@ class KerasInitializersTest(test.TestCase):
self._runner(keras.initializers.ones(), tensor_shape,
target_mean=1., target_max=1.)
+ def test_default_random_uniform(self):
+ ru = keras.initializers.get('uniform')
+ self.assertEqual(ru.minval, -0.05)
+ self.assertEqual(ru.maxval, 0.05)
+
+ def test_default_random_normal(self):
+ rn = keras.initializers.get('normal')
+ self.assertEqual(rn.mean, 0.0)
+ self.assertEqual(rn.stddev, 0.05)
+
+ def test_default_truncated_normal(self):
+ tn = keras.initializers.get('truncated_normal')
+ self.assertEqual(tn.mean, 0.0)
+ self.assertEqual(tn.stddev, 0.05)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 04b3aecff8..ba7498e7e6 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -2100,6 +2100,9 @@ class LSTMCell(Layer):
class LSTM(RNN):
"""Long Short-Term Memory layer - Hochreiter 1997.
+ Note that this cell is not optimized for performance on GPU. Please use
+ `tf.keras.layers.CuDNNLSTM` for better performance on GPU.
+
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
@@ -2195,6 +2198,10 @@ class LSTM(RNN):
logging.warning('`implementation=0` has been deprecated, '
'and now defaults to `implementation=1`.'
'Please update your layer call.')
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn('%s: Note that this layer is not optimized for performance. '
+ 'Please use tf.keras.layers.CuDNNLSTM for better '
+ 'performance on GPU.', self)
cell = LSTMCell(
units,
activation=activation,
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 14cf1ce2af..81c760b1f6 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -585,11 +585,15 @@ def categorical_accuracy(y_true, y_pred):
def sparse_categorical_accuracy(y_true, y_pred):
- return math_ops.cast(
- math_ops.equal(
- math_ops.reduce_max(y_true, axis=-1),
- math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())),
- K.floatx())
+ y_true = math_ops.reduce_max(y_true, axis=-1)
+ y_pred = math_ops.argmax(y_pred, axis=-1)
+
+ # If the expected labels are float, we need to cast the int returned by
+ # argmax to compare.
+ if K.dtype(y_true) == K.floatx():
+ y_pred = math_ops.cast(y_pred, K.floatx())
+
+ return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
@tf_export('keras.metrics.top_k_categorical_accuracy')
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 0bc95a3952..779c08c42d 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -50,9 +50,32 @@ class KerasMetricsTest(test.TestCase):
def test_sparse_categorical_accuracy(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
- y_a = K.variable(np.random.randint(0, 7, (6,)))
- y_b = K.variable(np.random.random((6, 7)))
- self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
+ y_true = K.variable(np.random.randint(0, 7, (6,)))
+ y_pred = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+
+ def test_sparse_categorical_accuracy_float(self):
+ with self.cached_session():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = K.variable(np.random.random((6,)))
+ y_pred = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+
+ def test_sparse_categorical_accuracy_eager(self):
+ """Tests that ints passed in via Eager return results. See b/113504761."""
+ with context.eager_mode():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = np.arange(6).reshape([6, 1])
+ y_pred = np.arange(36).reshape([6, 6])
+ self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])
+
+ def test_sparse_categorical_accuracy_float_eager(self):
+ """Tests that floats passed in via Eager return results. See b/113504761."""
+ with context.eager_mode():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = np.arange(6, dtype=np.float32).reshape([6, 1])
+ y_pred = np.arange(36).reshape([6, 6])
+ self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])
def test_sparse_top_k_categorical_accuracy(self):
with self.cached_session():
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 39b6042597..c3b7301eba 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -30,7 +30,6 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
-from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.util.tf_export import tf_export
@@ -394,10 +393,11 @@ def in_place_subclassed_model_state_restoration(model):
def clone_and_build_model(
model, input_tensors=None, target_tensors=None, custom_objects=None,
- compile_clone=True, in_place_reset=False):
+ compile_clone=True, in_place_reset=False, optimizer_iterations=None):
"""Clone a `Model` and build/compile it with the same settings used before.
- This function should be run in the same graph as the model.
+ This function can be be run in the same graph or in a separate graph from the
+ model. When using a separate graph, `in_place_reset` must be `False`.
Args:
model: `tf.keras.Model` object. Can be Functional, Sequential, or
@@ -414,6 +414,10 @@ def clone_and_build_model(
this argument must be set to `True` (default `False`). To restore the
original model, use the function
`in_place_subclassed_model_state_restoration(model)`.
+ optimizer_iterations: An iterations variable to pass to the optimizer if
+ the model uses a TFOptimizer, and if the clone is compiled. This is used
+ when a Keras model is cloned into an Estimator model function, because
+ Estimators create their own global step variable.
Returns:
Clone of the model.
@@ -448,14 +452,12 @@ def clone_and_build_model(
clone.build()
elif model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
- optimizer = model.optimizer
+ optimizer = optimizers.TFOptimizer(
+ model.optimizer.optimizer, optimizer_iterations)
K.track_tf_optimizer(optimizer)
else:
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
- global_step = training_util.get_or_create_global_step()
- K.track_variable(global_step)
- optimizer.iterations = global_step
clone.compile(
optimizer,
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 2ce79285db..ab13e5c632 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -692,11 +692,15 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
"""Wrapper class for native TensorFlow optimizers.
"""
- def __init__(self, optimizer): # pylint: disable=super-init-not-called
+ def __init__(self, optimizer, iterations=None): # pylint: disable=super-init-not-called
self.optimizer = optimizer
self._track_checkpointable(optimizer, name='optimizer')
- with K.name_scope(self.__class__.__name__):
- self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if iterations is None:
+ with K.name_scope(self.__class__.__name__):
+ self.iterations = K.variable(0, dtype='int64', name='iterations')
+ else:
+ self.iterations = iterations
+ self._track_checkpointable(self.iterations, name='global_step')
def apply_gradients(self, grads):
self.optimizer.apply_gradients(grads, global_step=self.iterations)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7671da11ab..3026c7755a 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1121,6 +1121,7 @@ tf_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:state_ops",
+ "//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index b0e24e969c..a164682227 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -560,13 +560,21 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
self.assertAllEqual([3.], self.evaluate(s))
@test_util.assert_no_new_pyobjects_executing_eagerly
- def testEagerMemory(self):
+ @test_util.assert_no_garbage_created
+ def testTensorSliceEagerMemory(self):
with context.eager_mode():
inputs = constant_op.constant(
[[[1], [2], [3], [4]]], dtype=dtypes.float32)
# Tests that slicing an EagerTensor doesn't leak memory
inputs[0] # pylint: disable=pointless-statement
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ @test_util.assert_no_garbage_created
+ def testVariableSliceEagerMemory(self):
+ with context.eager_mode():
+ v = variables.Variable([1., 2.])
+ v[0] # pylint: disable=pointless-statement
+
def testDegenerateSlices(self):
with self.test_session(use_gpu=True):
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 3193222262..9b6aee64aa 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -71,6 +71,36 @@ class ListOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
@test_util.run_in_graph_and_eager_modes
+ def testGatherGrad(self):
+ with backprop.GradientTape() as tape:
+ l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
+ element_shape=scalar_shape())
+ c0 = constant_op.constant(1.0)
+ tape.watch(c0)
+ l = list_ops.tensor_list_push_back(l, c0)
+ l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
+ t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
+ s = (t[0] + t[1]) * (t[0] + t[1])
+ dt = tape.gradient(s, c0)
+ self.assertAllEqual(self.evaluate(dt), 6.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testScatterGrad(self):
+ with backprop.GradientTape() as tape:
+ c0 = constant_op.constant([1.0, 2.0])
+ tape.watch(c0)
+ l = list_ops.tensor_list_scatter(
+ c0, [1, 0], ops.convert_to_tensor([], dtype=dtypes.int32))
+ t0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
+ t1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(t0), 2.0)
+ self.assertAllEqual(self.evaluate(t1), 1.0)
+ loss = t0 * t0 + t1 * t1
+ dt = tape.gradient(loss, c0)
+ self.assertAllEqual(self.evaluate(dt), [2., 4.])
+
+ @test_util.run_in_graph_and_eager_modes
def testStackGPU(self):
if not context.num_gpus():
return
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 59b3ee2013..7dff4501cc 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -60,8 +60,9 @@ def flatten(list_of_lists):
def flatten_values_tensors_or_sparse(tensors_list):
"""Flatten each SparseTensor object into 3 Tensors for session.run()."""
return list(
- flatten([[v.indices, v.values, v.dense_shape] if isinstance(
- v, sparse_tensor.SparseTensor) else [v] for v in tensors_list]))
+ flatten([[v.indices, v.values, v.dense_shape]
+ if isinstance(v, sparse_tensor.SparseTensor) else [v]
+ for v in tensors_list]))
def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
@@ -106,8 +107,9 @@ class ParseExampleTest(test.TestCase):
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
serialized = kwargs["serialized"]
- batch_size = (serialized.eval().size if isinstance(serialized, ops.Tensor)
- else np.asarray(serialized).size)
+ batch_size = (
+ serialized.eval().size if isinstance(serialized, ops.Tensor) else
+ np.asarray(serialized).size)
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(
@@ -129,12 +131,9 @@ class ParseExampleTest(test.TestCase):
c_default = np.random.rand(2).astype(np.float32)
expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
expected_output = {
sparse_name: expected_st_a,
@@ -143,28 +142,23 @@ class ParseExampleTest(test.TestCase):
c_name: np.array(2 * [c_default]),
}
- self._test(
- {
- "example_names":
- np.empty(
- (0,), dtype=bytes),
- "serialized":
- ops.convert_to_tensor(["", ""]),
- "features": {
- sparse_name:
- parsing_ops.VarLenFeature(dtypes.int64),
- a_name:
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- b_name:
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- c_name:
- parsing_ops.FixedLenFeature(
- (2,), dtypes.float32, default_value=c_default),
- }
- },
- expected_output)
+ self._test({
+ "example_names": np.empty((0,), dtype=bytes),
+ "serialized": ops.convert_to_tensor(["", ""]),
+ "features": {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ }
+ }, expected_output)
def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = {
@@ -180,8 +174,7 @@ class ParseExampleTest(test.TestCase):
default_value=np.random.rand(3, 3).astype(bytes)),
# Feature "c" is missing a default, this gap will cause failure.
"c":
- parsing_ops.FixedLenFeature(
- (2,), dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature((2,), dtype=dtypes.float32),
}
# Edge case where the key is there but the feature value is empty
@@ -211,7 +204,8 @@ class ParseExampleTest(test.TestCase):
original = [
example(features=features({
"a": float_feature([1, 1, 3]),
- })), example(features=features({
+ })),
+ example(features=features({
"a": float_feature([-1, -1]),
}))
]
@@ -231,7 +225,11 @@ class ParseExampleTest(test.TestCase):
"Name: failing, Key: a, Index: 1. Number of float val"))
def testDenseDefaultNoShapeShouldFail(self):
- original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })),
+ ]
serialized = [m.SerializeToString() for m in original]
@@ -250,31 +248,31 @@ class ParseExampleTest(test.TestCase):
example(features=features({
"st_c": float_feature([3, 4])
})),
- example(features=features({
- "st_c": float_feature([]), # empty float list
- })),
- example(features=features({
- "st_d": feature(), # feature with nothing in it
- })),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- }))
+ example(
+ features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(
+ features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(
+ features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_st_c = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
- [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
- [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
+ np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
+ np.array([4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
expected_st_d = ( # indices, values, shape
- np.array(
- [[3, 0]], dtype=np.int64), np.array(
- ["hi"], dtype=bytes), np.array(
- [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
+ np.array([[3, 0]], dtype=np.int64), np.array(["hi"], dtype=bytes),
+ np.array([4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
expected_output = {
"st_c": expected_st_c,
@@ -291,70 +289,74 @@ class ParseExampleTest(test.TestCase):
def testSerializedContainingSparseFeature(self):
original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx":
- int64_feature([0, 9, 3]) # unsorted
- }))
+ example(
+ features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(
+ features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(
+ features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(
+ features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
- np.array(
- [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
- [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ np.array([4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
- expected_output = {"sp": expected_sp,}
+ expected_output = {
+ "sp": expected_sp,
+ }
self._test({
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "sp": parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.float32, [13])
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])
}
}, expected_output)
def testSerializedContainingSparseFeatureReuse(self):
original = [
- example(features=features({
- "val1": float_feature([3, 4]),
- "val2": float_feature([5, 6]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val1": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
+ example(
+ features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(
+ features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
]
serialized = [m.SerializeToString() for m in original]
expected_sp1 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [3.0, 4.0], dtype=np.float32), np.array(
- [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
expected_sp2 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [5.0, 6.0], dtype=np.float32), np.array(
- [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([5.0, 6.0], dtype=np.float32), np.array(
+ [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
expected_output = {
"sp1": expected_sp1,
@@ -374,25 +376,29 @@ class ParseExampleTest(test.TestCase):
def testSerializedContaining3DSparseFeature(self):
original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx0": int64_feature([5, 10]),
- "idx1": int64_feature([0, 2]),
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx0": int64_feature([]),
- "idx1": int64_feature([]),
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx0": int64_feature([0, 9, 3]), # unsorted
- "idx1": int64_feature([1, 0, 2]),
- }))
+ example(
+ features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(
+ features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(
+ features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(
+ features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
]
serialized = [m.SerializeToString() for m in original]
@@ -407,13 +413,16 @@ class ParseExampleTest(test.TestCase):
# shape batch == 4, max_elems = 13
np.array([4, 13, 3], dtype=np.int64))
- expected_output = {"sp": expected_sp,}
+ expected_output = {
+ "sp": expected_sp,
+ }
self._test({
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "sp": parsing_ops.SparseFeature(
- ["idx0", "idx1"], "val", dtypes.float32, [13, 3])
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
}
}, expected_output)
@@ -421,41 +430,37 @@ class ParseExampleTest(test.TestCase):
aname = "a"
bname = "b*has+a:tricky_name"
original = [
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- })), example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b""]),
- }))
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_output = {
aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
bname:
- np.array(
- ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ np.array(["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ }
+ }, expected_output)
# This test is identical as the previous one except
# for the creation of 'serialized'.
@@ -466,18 +471,22 @@ class ParseExampleTest(test.TestCase):
original = [
(example(features=features({
aname: float_feature([10, 10]),
- })), example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- }))),
+ })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
(
example(features=features({
bname: bytes_feature([b"b100"]),
})),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"]),
- })),),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ ),
]
serialized = [
@@ -486,55 +495,45 @@ class ParseExampleTest(test.TestCase):
expected_output = {
aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
bname:
- np.array(
- ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ np.array(["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ }
+ }, expected_output)
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
"a": float_feature([1]),
- })), example(features=features({}))
+ })),
+ example(features=features({}))
]
serialized = [m.SerializeToString() for m in original]
expected_output = {
"a":
- np.array(
- [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ np.array([[1], [-1]], dtype=np.float32) # 2x1 (column vector)
}
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- "a":
- parsing_ops.FixedLenFeature(
- (1,), dtype=dtypes.float32, default_value=-1),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ }
+ }, expected_output)
def testSerializedContainingDenseWithDefaults(self):
original = [
@@ -553,58 +552,48 @@ class ParseExampleTest(test.TestCase):
expected_output = {
"a":
- np.array(
- [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
- 1),
+ np.array([[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(
+ 3, 1, 2, 1),
"b":
- np.array(
- ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
- 1),
+ np.array(["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(
+ 3, 1, 1, 1, 1),
}
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- "a":
- parsing_ops.FixedLenFeature(
- (1, 2, 1),
- dtype=dtypes.float32,
- default_value=[3.0, -3.0]),
- "b":
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1),
- dtype=dtypes.string,
- default_value="tmp_str"),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ }
+ }, expected_output)
def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "c"], dtype="|S"), np.array(
+ [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
original = [
- example(features=features({
- "c": float_feature([3, 4]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "c": float_feature([1, 2]),
- "val": bytes_feature([b"c"]),
- "idx": int64_feature([7])
- }))
+ example(
+ features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })),
+ example(
+ features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
]
names = ["in1", "in2"]
@@ -617,16 +606,13 @@ class ParseExampleTest(test.TestCase):
"sp": expected_sp,
"a": np.array(2 * [[a_default]]),
"b": np.array(2 * [b_default]),
- "c": np.array(
- [[3, 4], [1, 2]], dtype=np.float32),
+ "c": np.array([[3, 4], [1, 2]], dtype=np.float32),
}
self._test(
{
- "example_names":
- names,
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_names": names,
+ "serialized": ops.convert_to_tensor(serialized),
"features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.int64),
@@ -647,25 +633,26 @@ class ParseExampleTest(test.TestCase):
def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
- np.array([0, 3, 7, 1]), np.array(
- [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
+ np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]),
+ np.array([2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "d", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "d", "c"], dtype="|S"),
+ np.array([2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
original = [
- example(features=features({
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "val": bytes_feature([b"c", b"d"]),
- "idx": int64_feature([7, 1])
- }))
+ example(
+ features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })),
+ example(
+ features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
]
names = ["in1", "in2"]
@@ -680,9 +667,10 @@ class ParseExampleTest(test.TestCase):
"example_names": names,
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "idx": parsing_ops.VarLenFeature(dtypes.int64),
- "sp": parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.string, [13]),
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
}
}, expected_output)
@@ -720,10 +708,11 @@ class ParseExampleTest(test.TestCase):
}
original = [
- example(features=features(
- {"a": int64_feature([truth_int[i]]),
- "b": bytes_feature(truth_str[i])}))
- for i in range(batch_size)
+ example(
+ features=features({
+ "a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])
+ })) for i in range(batch_size)
]
serialized = [m.SerializeToString() for m in original]
@@ -731,12 +720,18 @@ class ParseExampleTest(test.TestCase):
self._test({
"serialized": ops.convert_to_tensor(serialized, dtype=dtypes.string),
"features": {
- "a": parsing_ops.FixedLenSequenceFeature(
- shape=(), dtype=dtypes.int64, allow_missing=True,
- default_value=-1),
- "b": parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True,
- default_value="default"),
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
}
}, expected_output)
@@ -755,18 +750,21 @@ class ParseExampleTest(test.TestCase):
example(features=features({
cname: int64_feature([2]),
})),
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str", b"b1_str"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1, 2, 2]),
- bname: bytes_feature([b"b1"]),
- })),
- example(features=features({
- aname: float_feature([]),
- cname: int64_feature([3]),
- })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
]
serialized = [m.SerializeToString() for m in original]
@@ -827,7 +825,9 @@ class ParseExampleTest(test.TestCase):
"features": {
aname:
parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True,
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
default_value=-2.0),
bname:
parsing_ops.FixedLenSequenceFeature(
@@ -867,7 +867,9 @@ class ParseExampleTest(test.TestCase):
"features": {
aname:
parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True,
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
default_value=[]),
bname:
parsing_ops.FixedLenSequenceFeature(
@@ -908,26 +910,28 @@ class ParseExampleTest(test.TestCase):
"All dimensions of shape for feature c need to be known "
r"but received \(1, None\)."))
- self._test({
- "example_names": example_names,
- "serialized": ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=False),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- }
- }, expected_err=(ValueError,
- "Unsupported: FixedLenSequenceFeature requires "
- "allow_missing to be True."))
+ self._test(
+ {
+ "example_names": example_names,
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
class ParseSingleExampleTest(test.TestCase):
@@ -949,8 +953,8 @@ class ParseSingleExampleTest(test.TestCase):
# Check shapes.
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
- self.assertEqual(tuple(out[k].get_shape()),
- tensor_shape.as_shape(f.shape))
+ self.assertEqual(
+ tuple(out[k].get_shape()), tensor_shape.as_shape(f.shape))
elif isinstance(f, parsing_ops.VarLenFeature):
self.assertEqual(
tuple(out[k].indices.get_shape().as_list()), (None, 1))
@@ -959,29 +963,25 @@ class ParseSingleExampleTest(test.TestCase):
tuple(out[k].dense_shape.get_shape().as_list()), (1,))
def testSingleExampleWithSparseAndSparseFeatureAndDense(self):
- original = example(features=features({
- "c": float_feature([3, 4]),
- "d": float_feature([0.0, 1.0]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3]),
- "st_a": float_feature([3.0, 4.0])
- }))
+ original = example(
+ features=features({
+ "c": float_feature([3, 4]),
+ "d": float_feature([0.0, 1.0]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3]),
+ "st_a": float_feature([3.0, 4.0])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0], [1]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0], dtype=np.float32), # values
- np.array(
- [2], dtype=np.int64)) # shape: max_values = 2
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: max_values = 2
expected_sp = ( # indices, values, shape
- np.array(
- [[0], [3]], dtype=np.int64), np.array(
- ["a", "b"], dtype="|S"), np.array(
- [13], dtype=np.int64)) # max_values = 13
+ np.array([[0], [3]], dtype=np.int64), np.array(["a", "b"], dtype="|S"),
+ np.array([13], dtype=np.int64)) # max_values = 13
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
@@ -996,16 +996,14 @@ class ParseSingleExampleTest(test.TestCase):
self._test(
{
- "example_names":
- ops.convert_to_tensor("in1"),
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_names": ops.convert_to_tensor("in1"),
+ "serialized": ops.convert_to_tensor(serialized),
"features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
"sp":
- parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.string, [13]),
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string,
+ [13]),
"a":
parsing_ops.FixedLenFeature(
(1, 3), dtypes.int64, default_value=a_default),
@@ -1016,9 +1014,8 @@ class ParseSingleExampleTest(test.TestCase):
"c":
parsing_ops.FixedLenFeature(2, dtypes.float32),
"d":
- parsing_ops.FixedLenSequenceFeature([],
- dtypes.float32,
- allow_missing=True)
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True)
}
},
expected_output)
@@ -1050,43 +1047,71 @@ class ParseSequenceExampleTest(test.TestCase):
kwargs,
expected_context_values=None,
expected_feat_list_values=None,
- expected_err=None):
+ expected_length_values=None,
+ expected_err=None,
+ batch=False):
expected_context_values = expected_context_values or {}
expected_feat_list_values = expected_feat_list_values or {}
+ expected_length_values = expected_length_values or {}
with self.test_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
- c_out, fl_out = parsing_ops.parse_single_sequence_example(**kwargs)
+ if batch:
+ c_out, fl_out, _ = parsing_ops.parse_sequence_example(**kwargs)
+ else:
+ c_out, fl_out = parsing_ops.parse_single_sequence_example(**kwargs)
if c_out:
sess.run(flatten_values_tensors_or_sparse(c_out.values()))
if fl_out:
sess.run(flatten_values_tensors_or_sparse(fl_out.values()))
else:
# Returns dicts w/ Tensors and SparseTensors.
- context_out, feat_list_out = parsing_ops.parse_single_sequence_example(
- **kwargs)
+ if batch:
+ (context_out, feat_list_out,
+ lengths_out) = parsing_ops.parse_sequence_example(**kwargs)
+ else:
+ (context_out,
+ feat_list_out) = parsing_ops.parse_single_sequence_example(**kwargs)
+ lengths_out = {}
+
context_result = sess.run(
- flatten_values_tensors_or_sparse(context_out.values(
- ))) if context_out else []
+ flatten_values_tensors_or_sparse(
+ context_out.values())) if context_out else []
feat_list_result = sess.run(
- flatten_values_tensors_or_sparse(feat_list_out.values(
- ))) if feat_list_out else []
+ flatten_values_tensors_or_sparse(
+ feat_list_out.values())) if feat_list_out else []
+ lengths_result = sess.run(
+ flatten_values_tensors_or_sparse(
+ lengths_out.values())) if lengths_out else []
# Check values.
_compare_output_to_expected(self, context_out, expected_context_values,
context_result)
_compare_output_to_expected(self, feat_list_out,
expected_feat_list_values, feat_list_result)
+ _compare_output_to_expected(self, lengths_out, expected_length_values,
+ lengths_result)
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
if "context_features" in kwargs:
for k, f in kwargs["context_features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ if batch:
+ self.assertEqual(
+ tuple(context_out[k].get_shape().as_list()[1:]), f.shape)
+ else:
+ self.assertEqual(
+ tuple(context_out[k].get_shape().as_list()), f.shape)
+ elif isinstance(f, parsing_ops.VarLenFeature) and batch:
self.assertEqual(
- tuple(context_out[k].get_shape().as_list()), f.shape)
- elif isinstance(f, parsing_ops.VarLenFeature):
+ tuple(context_out[k].indices.get_shape().as_list()), (None, 2))
+ self.assertEqual(
+ tuple(context_out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(
+ tuple(context_out[k].dense_shape.get_shape().as_list()), (2,))
+ elif isinstance(f, parsing_ops.VarLenFeature) and not batch:
self.assertEqual(
tuple(context_out[k].indices.get_shape().as_list()), (None, 1))
self.assertEqual(
@@ -1094,38 +1119,94 @@ class ParseSequenceExampleTest(test.TestCase):
self.assertEqual(
tuple(context_out[k].dense_shape.get_shape().as_list()), (1,))
+ def _testBoth(self,
+ kwargs,
+ expected_context_values=None,
+ expected_feat_list_values=None,
+ expected_err=None):
+ # Test using tf.parse_single_sequence_example
+ self._test(
+ kwargs,
+ expected_context_values=expected_context_values,
+ expected_feat_list_values=expected_feat_list_values,
+ expected_err=expected_err,
+ batch=False)
+
+ # Convert the input to a batch of size 1, and test using
+ # tf.parse_sequence_example.
+
+ # Some replacements are needed for the batch version.
+ kwargs["serialized"] = [kwargs.pop("serialized")]
+ kwargs["example_names"] = [kwargs.pop("example_name")
+ ] if "example_name" in kwargs else None
+ # Disable error string matching; it's not consistent for batch mode.
+ if expected_err:
+ expected_err = (expected_err[0], "")
+
+ # Add a batch dimension to expected output
+ if expected_context_values:
+ new_values = {}
+ for k in expected_context_values:
+ v = expected_context_values[k]
+ if isinstance(kwargs["context_features"][k],
+ parsing_ops.FixedLenFeature):
+ new_values[k] = np.expand_dims(v, axis=0)
+ else:
+ # Sparse tensor.
+ new_values[k] = (np.insert(v[0], 0, 0, axis=1), v[1],
+ np.insert(v[2], 0, 1))
+ expected_context_values = new_values
+
+ expected_length_values = {}
+ if expected_feat_list_values:
+ new_values = {}
+ for k in expected_feat_list_values:
+ v = expected_feat_list_values[k]
+ if isinstance(kwargs["sequence_features"][k],
+ parsing_ops.FixedLenSequenceFeature):
+ expected_length_values[k] = [np.shape(v)[0]]
+ new_values[k] = np.expand_dims(v, axis=0)
+ else:
+ # Sparse tensor.
+ new_values[k] = (np.insert(v[0], 0, 0, axis=1), v[1],
+ np.insert(v[2], 0, 1))
+ expected_feat_list_values = new_values
+
+ self._test(
+ kwargs,
+ expected_context_values=expected_context_values,
+ expected_feat_list_values=expected_feat_list_values,
+ expected_length_values=expected_length_values,
+ expected_err=expected_err,
+ batch=True)
+
def testSequenceExampleWithSparseAndDenseContext(self):
- original = sequence_example(context=features({
- "c": float_feature([3, 4]),
- "st_a": float_feature([3.0, 4.0])
- }))
+ original = sequence_example(
+ context=features({
+ "c": float_feature([3, 4]),
+ "st_a": float_feature([3.0, 4.0])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0], [1]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0], dtype=np.float32), # values
- np.array(
- [2], dtype=np.int64)) # shape: num_features = 2
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: num_features = 2
- a_default = [1, 2, 3]
+ a_default = [[1, 2, 3]]
b_default = np.random.rand(3, 3).astype(bytes)
expected_context_output = {
"st_a": expected_st_a,
- "a": [a_default],
+ "a": a_default,
"b": b_default,
- "c": np.array(
- [3, 4], dtype=np.float32),
+ "c": np.array([3, 4], dtype=np.float32),
}
- self._test(
+ self._testBoth(
{
- "example_name":
- "in1",
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_name": "in1",
+ "serialized": ops.convert_to_tensor(serialized),
"context_features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
@@ -1143,51 +1224,54 @@ class ParseSequenceExampleTest(test.TestCase):
expected_context_values=expected_context_output)
def testSequenceExampleWithMultipleSizeFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([
- int64_feature([-1, 0, 1]),
- int64_feature([2, 3, 4]),
- int64_feature([5, 6, 7]),
- int64_feature([8, 9, 10]),
- ]),
- "b":
- feature_list([bytes_feature([b"r00", b"r01", b"r10", b"r11"])]),
- "c":
- feature_list([float_feature([3, 4]), float_feature([-1, 2])]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([-1, 0, 1]),
+ int64_feature([2, 3, 4]),
+ int64_feature([5, 6, 7]),
+ int64_feature([8, 9, 10]),
+ ]),
+ "b":
+ feature_list([bytes_feature([b"r00", b"r01", b"r10", b"r11"])]),
+ "c":
+ feature_list([float_feature([3, 4]),
+ float_feature([-1, 2])]),
+ }))
serialized = original.SerializeToString()
expected_feature_list_output = {
- "a": np.array(
- [ # outer dimension is time.
- [[-1, 0, 1]], # inside are 1x3 matrices
- [[2, 3, 4]],
- [[5, 6, 7]],
- [[8, 9, 10]]
- ],
- dtype=np.int64),
- "b": np.array(
- [ # outer dimension is time, inside are 2x2 matrices
- [[b"r00", b"r01"], [b"r10", b"r11"]]
- ],
- dtype=bytes),
- "c": np.array(
- [ # outer dimension is time, inside are 2-vectors
- [3, 4], [-1, 2]
- ],
- dtype=np.float32),
- "d": np.empty(
- shape=(0, 5), dtype=np.float32), # empty_allowed_missing
+ "a":
+ np.array(
+ [ # outer dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]
+ ],
+ dtype=np.int64),
+ "b":
+ np.array(
+ [ # outer dimension is time, inside are 2x2 matrices
+ [[b"r00", b"r01"], [b"r10", b"r11"]]
+ ],
+ dtype=bytes),
+ "c":
+ np.array(
+ [ # outer dimension is time, inside are 2-vectors
+ [3, 4], [-1, 2]
+ ],
+ dtype=np.float32),
+ "d":
+ np.empty(shape=(0, 5), dtype=np.float32), # empty_allowed_missing
}
- self._test(
+ self._testBoth(
{
- "example_name":
- "in1",
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_name": "in1",
+ "serialized": ops.convert_to_tensor(serialized),
"sequence_features": {
"a":
parsing_ops.FixedLenSequenceFeature((1, 3), dtypes.int64),
@@ -1203,56 +1287,51 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithoutDebugName(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([int64_feature([3, 4]), int64_feature([1, 0])]),
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]), float_feature([5.0]),
- float_feature([])
- ]),
- "st_b":
- feature_list([
- bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
- bytes_feature([b"b", b"c"])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])
+ ]),
+ "st_b":
+ feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array(
- [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(
- ["a", "b", "c"], dtype="|S"), # values
- np.array(
- [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype="|S"), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # values
- np.array(
- [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array(
- [[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test(
+ self._testBoth(
{
"serialized": ops.convert_to_tensor(serialized),
"sequence_features": {
@@ -1265,56 +1344,51 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithSparseAndDenseFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([int64_feature([3, 4]), int64_feature([1, 0])]),
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]), float_feature([5.0]),
- float_feature([])
- ]),
- "st_b":
- feature_list([
- bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
- bytes_feature([b"b", b"c"])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])
+ ]),
+ "st_b":
+ feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array(
- [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(
- ["a", "b", "c"], dtype="|S"), # values
- np.array(
- [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype="|S"), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # values
- np.array(
- [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array(
- [[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1328,30 +1402,28 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithEmptyFeatureInFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]),
- feature(),
- float_feature([5.0]),
- ]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ feature(),
+ float_feature([5.0]),
+ ]),
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [2, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_feature_list_output = {
"st_a": expected_st_a,
}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1362,13 +1434,15 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleListWithInconsistentDataFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([int64_feature([-1, 0]), float_feature([2, 3])])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([int64_feature([-1, 0]),
+ float_feature([2, 3])])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1380,13 +1454,14 @@ class ParseSequenceExampleTest(test.TestCase):
" Data types don't match. Expected type: int64"))
def testSequenceExampleListWithWrongDataTypeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([float_feature([2, 3])])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([float_feature([2, 3])])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1399,17 +1474,19 @@ class ParseSequenceExampleTest(test.TestCase):
" Expected type: int64"))
def testSequenceExampleListWithWrongSparseDataTypeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([
- int64_feature([3, 4]), int64_feature([1, 2]),
- float_feature([2.0, 3.0])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 2]),
+ float_feature([2.0, 3.0])
+ ])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1423,13 +1500,16 @@ class ParseSequenceExampleTest(test.TestCase):
" Feature is: float_list"))
def testSequenceExampleListWithWrongShapeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([int64_feature([2, 3]), int64_feature([2, 3, 4])]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([2, 3]),
+ int64_feature([2, 3, 4])]),
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1446,7 +1526,7 @@ class ParseSequenceExampleTest(test.TestCase):
# Test fails because we didn't add:
# feature_list_dense_defaults = {"a": None}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(original.SerializeToString()),
@@ -1461,6 +1541,67 @@ class ParseSequenceExampleTest(test.TestCase):
" feature_list_dense_missing_assumed_empty or"
" feature_list_dense_defaults?"))
+ def testSequenceExampleBatch(self):
+ first = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([-1, 0, 1]),
+ int64_feature([2, 3, 4]),
+ int64_feature([5, 6, 7]),
+ int64_feature([8, 9, 10]),
+ ])
+ }))
+ second = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([21, 2, 11]),
+ ])
+ }))
+
+ serialized = [first.SerializeToString(), second.SerializeToString()]
+
+ expected_feature_list_output = {
+ "a":
+ np.array(
+ [ # outermost dimension is example id
+ [ # middle dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]
+ ],
+ [ # middle dimension is time.
+ [[21, 2, 11]], # inside are 1x3 matrices
+ [[0, 0, 0]], # additional entries are padded with 0
+ [[0, 0, 0]],
+ [[0, 0, 0]]
+ ]
+ ],
+ dtype=np.int64),
+ "d":
+ np.empty(shape=(2, 0, 5), dtype=np.float32), # allowed_missing
+ }
+
+ self._test(
+ {
+ "example_names": ops.convert_to_tensor(["in1", "in2"]),
+ "serialized": ops.convert_to_tensor(serialized),
+ "sequence_features": {
+ "a":
+ parsing_ops.FixedLenSequenceFeature((1, 3), dtypes.int64),
+ "d":
+ parsing_ops.FixedLenSequenceFeature(
+ (5,), dtypes.float32, allow_missing=True),
+ }
+ },
+ expected_feat_list_values=expected_feature_list_output,
+ expected_length_values={
+ "a": [4, 1],
+ "d": [0, 0]
+ },
+ batch=True)
+
class DecodeJSONExampleTest(test.TestCase):
@@ -1531,24 +1672,27 @@ class DecodeJSONExampleTest(test.TestCase):
example(features=features({
"st_d": feature()
})),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- })),
+ example(
+ features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ })),
])
def testSerializedContainingBytes(self):
aname = "a"
bname = "b*has+a:tricky_name"
self._testRoundTrip([
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"])
- })),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"])
- })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"])
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"])
+ })),
])
def testInvalidSyntax(self):
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index 4935ed6ca5..f50e39d6d5 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -157,7 +157,7 @@ class MatMulGradientTest(test.TestCase):
m, [3, 4],
x_init_value=b.eval(),
delta=delta))
- self.assertLess(err, delta / 2.)
+ self.assertLessEqual(err, delta / 2.)
def testGradientInput(self):
for tr_a in [True, False]:
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index cb5a66312f..fc39de150e 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -205,6 +206,22 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
output = sess.run(sp_output)
self._AssertResultsNotSorted(output, vocab_size)
+ def testShouldSetLastDimensionInDynamicShape(self):
+ with ops.Graph().as_default():
+ shape = constant_op.constant([2, 2], dtype=dtypes.int64)
+ dynamic_shape = array_ops.placeholder_with_default(shape, shape=[2])
+ ids = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]],
+ values=[1, 3],
+ dense_shape=dynamic_shape)
+ values = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]],
+ values=[0.4, 0.7],
+ dense_shape=dynamic_shape)
+ merged = sparse_ops.sparse_merge(
+ sp_ids=ids, sp_values=values, vocab_size=5)
+ self.assertEqual(5, merged.get_shape()[1])
+
class SparseMergeHighDimTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index b736b12416..d57b79cb90 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect
class VariableScopeTest(test.TestCase):
@@ -995,6 +996,13 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(
variable_scope.get_local_variable("w", []).name, "outer/w:0")
+ def testSignatureGetVarVsGetLocalVar(self):
+ """get_{local,}variable() must take the same list of args."""
+ arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0]
+ local_arg_names = tf_inspect.getargspec(
+ variable_scope.get_local_variable)[0]
+ self.assertEqual(arg_names, local_arg_names)
+
def testGetVarWithDevice(self):
g = ops.Graph()
varname_type = []
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 7bf3869ddf..21ccbc6c33 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -693,9 +693,6 @@ def strided_slice(input_,
parent_name = name
if not (var is None and isinstance(op, ops.EagerTensor)):
- # TODO(b/113297051): Assigning a function to an EagerTensor seems to leak
- # memory. Slicing variables still leaks, although ".assign" is removed for
- # EagerTensors which are not variable slices to mitigate the issue.
def assign(val, name=None):
"""Closure that holds all the arguments to create an assignment."""
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 6f3cd74406..78c4b4bfe0 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class CollectiveOpTest(test.TestCase):
- def _testCollectiveReduce(self, t0, t1, expected):
+ def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
group_key = 1
instance_key = 1
with self.test_session(
@@ -43,7 +43,8 @@ class CollectiveOpTest(test.TestCase):
colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
'Add', 'Div')
run_options = config_pb2.RunOptions()
- run_options.experimental.collective_graph_key = 1
+ if set_graph_key:
+ run_options.experimental.collective_graph_key = 1
results = sess.run([colred0, colred1], options=run_options)
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
@@ -51,10 +52,15 @@ class CollectiveOpTest(test.TestCase):
def testCollectiveReduce(self):
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
- [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
+ [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
+
+ def testCollectiveAutoGraphKey(self):
+ self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+ [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
+ [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
def testCollectiveReduceScalar(self):
- self._testCollectiveReduce(0.1, 0.3, 0.2)
+ self._testCollectiveReduce(0.1, 0.3, 0.2, True)
def _testCollectiveBroadcast(self, t0):
group_key = 1
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 871f236f78..d7834ba350 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -82,11 +82,10 @@ def custom_gradient(f):
scope must be using `ResourceVariable`s.
Args:
- f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
- - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
+ f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
+ - `x` is a sequence of `Tensor` inputs to the function.
- `y` is a `Tensor` or sequence of `Tensor` outputs of applying
- TensorFlow
- operations in `f` to `x`.
+ TensorFlow operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
@@ -96,7 +95,8 @@ def custom_gradient(f):
signature `g(*grad_ys, variables=None)`, where `variables` is a list of
the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
`grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
- with the derivatives of `Tensor`s in `y` with respect to the variables.
+ with the derivatives of `Tensor`s in `y` with respect to the variables
+ (that is, grad_vars has one Tensor per variable in variables).
Returns:
A function `h(x)` which returns the same value as `f(x)[0]` and whose
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a6be82673f..a4e7c84ae4 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -91,7 +91,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
Example:
```python
- elems = [1, 2, 3, 4, 5, 6]
+ elems = tf.constant([1, 2, 3, 4, 5, 6])
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
```
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 4d75ee3974..fff3d9b930 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -39,12 +39,12 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.util.deprecation import (
- deprecated, deprecated_arg_values)
+from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.deprecation import deprecated_arg_values
from tensorflow.python.util.tf_export import tf_export
@@ -226,9 +226,7 @@ class Constant(Initializer):
return {"value": self.value, "dtype": self.dtype.name}
-@tf_export("keras.initializers.RandomUniform", "initializers.random_uniform",
- "random_uniform_initializer", "keras.initializers.uniform",
- "keras.initializers.random_uniform")
+@tf_export("initializers.random_uniform", "random_uniform_initializer")
class RandomUniform(Initializer):
"""Initializer that generates tensors with a uniform distribution.
@@ -264,9 +262,7 @@ class RandomUniform(Initializer):
}
-@tf_export("keras.initializers.RandomNormal", "initializers.random_normal",
- "random_normal_initializer", "keras.initializers.normal",
- "keras.initializers.random_normal")
+@tf_export("initializers.random_normal", "random_normal_initializer")
class RandomNormal(Initializer):
"""Initializer that generates tensors with a normal distribution.
@@ -302,9 +298,7 @@ class RandomNormal(Initializer):
}
-@tf_export("keras.initializers.TruncatedNormal",
- "initializers.truncated_normal", "truncated_normal_initializer",
- "keras.initializers.truncated_normal")
+@tf_export("initializers.truncated_normal", "truncated_normal_initializer")
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
@@ -1116,29 +1110,10 @@ class Identity(Initializer):
def get_config(self):
return {"gain": self.gain, "dtype": self.dtype.name}
-# Aliases.
-
-# pylint: disable=invalid-name
-zeros_initializer = Zeros
-ones_initializer = Ones
-constant_initializer = Constant
-random_uniform_initializer = RandomUniform
-random_normal_initializer = RandomNormal
-truncated_normal_initializer = TruncatedNormal
-uniform_unit_scaling_initializer = UniformUnitScaling
-variance_scaling_initializer = VarianceScaling
-orthogonal_initializer = Orthogonal
-identity_initializer = Identity
-convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
-convolutional_orthogonal_1d = ConvolutionOrthogonal1D
-convolutional_orthogonal_2d = ConvolutionOrthogonal2D
-convolutional_orthogonal_3d = ConvolutionOrthogonal3D
-# pylint: enable=invalid-name
-
@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform",
"initializers.glorot_uniform")
-def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
+class GlorotUniform(VarianceScaling):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
It draws samples from a uniform distribution within [-limit, limit]
@@ -1153,17 +1128,28 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
-
- Returns:
- An initializer.
"""
- return variance_scaling_initializer(
- scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
+
+ def __init__(self,
+ seed=None,
+ dtype=dtypes.float32):
+ super(GlorotUniform, self).__init__(
+ scale=1.0,
+ mode="fan_avg",
+ distribution="uniform",
+ seed=seed,
+ dtype=dtype)
+
+ def get_config(self):
+ return {
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal",
"initializers.glorot_normal")
-def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
+class GlorotNormal(VarianceScaling):
"""The Glorot normal initializer, also called Xavier normal initializer.
It draws samples from a truncated normal distribution centered on 0
@@ -1178,16 +1164,45 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
-
- Returns:
- An initializer.
"""
- return variance_scaling_initializer(
- scale=1.0,
- mode="fan_avg",
- distribution="truncated_normal",
- seed=seed,
- dtype=dtype)
+
+ def __init__(self,
+ seed=None,
+ dtype=dtypes.float32):
+ super(GlorotNormal, self).__init__(
+ scale=1.0,
+ mode="fan_avg",
+ distribution="truncated_normal",
+ seed=seed,
+ dtype=dtype)
+
+ def get_config(self):
+ return {
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
+
+
+# Aliases.
+
+# pylint: disable=invalid-name
+zeros_initializer = Zeros
+ones_initializer = Ones
+constant_initializer = Constant
+random_uniform_initializer = RandomUniform
+random_normal_initializer = RandomNormal
+truncated_normal_initializer = TruncatedNormal
+uniform_unit_scaling_initializer = UniformUnitScaling
+variance_scaling_initializer = VarianceScaling
+glorot_uniform_initializer = GlorotUniform
+glorot_normal_initializer = GlorotNormal
+orthogonal_initializer = Orthogonal
+identity_initializer = Identity
+convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
+convolutional_orthogonal_1d = ConvolutionOrthogonal1D
+convolutional_orthogonal_2d = ConvolutionOrthogonal2D
+convolutional_orthogonal_3d = ConvolutionOrthogonal3D
+# pylint: enable=invalid-name
@tf_export("keras.initializers.lecun_normal", "initializers.lecun_normal")
diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py
index 6a1fe17119..5693c3caaf 100644
--- a/tensorflow/python/ops/init_ops_test.py
+++ b/tensorflow/python/ops/init_ops_test.py
@@ -20,10 +20,14 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -163,6 +167,40 @@ class InitializersTest(test.TestCase):
with self.cached_session():
self._runner(init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.)
+ def testVariablePlacementWithOrthogonalInitializer(self):
+ if not context.context().num_gpus():
+ self.skipTest('No devices other than CPUs found')
+ with ops.Graph().as_default() as g:
+ with ops.device('gpu:0'):
+ variable_scope.get_variable(
+ name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
+ variable_scope.get_variable(
+ name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ config = config_pb2.ConfigProto(
+ allow_soft_placement=False, log_device_placement=True)
+
+ # Note: allow_soft_placement=False will fail whenever we cannot satisfy
+ # the colocation constraints.
+ with session.Session(config=config, graph=g) as sess:
+ sess.run(
+ variables.global_variables_initializer(),
+ options=run_options,
+ run_metadata=run_metadata)
+
+ def test_eager_orthogonal_gpu(self):
+ if not context.context().num_gpus():
+ self.skipTest('No devices other than CPUs found')
+ with context.eager_mode():
+ v = variable_scope.get_variable(
+ name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
+ w = variable_scope.get_variable(
+ name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
+ self.assertTrue('GPU' in v.handle.device)
+ self.assertTrue('GPU' in w.handle.device)
+
def test_Identity(self):
with self.cached_session():
tensor_shape = (3, 4, 5)
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index d9ede87530..145a5f358c 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -97,3 +97,18 @@ def _TensorListSetItemGrad(op, dlist):
element_grad = gen_list_ops.tensor_list_get_item(
dlist, index, element_dtype=item.dtype)
return list_grad, index_grad, element_grad
+
+
+@ops.RegisterGradient("TensorListGather")
+def _TensorListGatherGrad(op, dtensor):
+ _, indices = op.inputs
+ return gen_list_ops.tensor_list_scatter(
+ tensor=dtensor, indices=indices,
+ element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32)), None
+
+
+@ops.RegisterGradient("TensorListScatter")
+def _TensorListScatterGrad(op, dlist):
+ t, indices, _ = op.inputs
+ return gen_list_ops.tensor_list_gather(
+ dlist, indices, element_dtype=t.dtype), None
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 6041e2a0c5..8224097ac4 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -897,6 +897,352 @@ def _parse_single_example_raw(serialized,
return outputs
+@tf_export("io.parse_sequence_example")
+def parse_sequence_example(serialized,
+ context_features=None,
+ sequence_features=None,
+ example_names=None,
+ name=None):
+ # pylint: disable=line-too-long
+ """Parses a batch of `SequenceExample` protos.
+
+ Parses a vector of serialized
+ [`SequenceExample`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
+ protos given in `serialized`.
+
+ This op parses serialized sequence examples into a tuple of dictionaries
+ mapping keys to `Tensor` and `SparseTensor` objects respectively.
+ The first dictionary contains mappings for keys appearing in
+ `context_features`, and the second dictionary contains mappings for keys
+ appearing in `sequence_features`.
+
+ At least one of `context_features` and `sequence_features` must be provided
+ and non-empty.
+
+ The `context_features` keys are associated with a `SequenceExample` as a
+ whole, independent of time / frame. In contrast, the `sequence_features` keys
+ provide a way to access variable-length data within the `FeatureList` section
+ of the `SequenceExample` proto. While the shapes of `context_features` values
+ are fixed with respect to frame, the frame dimension (the first dimension)
+ of `sequence_features` values may vary between `SequenceExample` protos,
+ and even between `feature_list` keys within the same `SequenceExample`.
+
+ `context_features` contains `VarLenFeature` and `FixedLenFeature` objects.
+ Each `VarLenFeature` is mapped to a `SparseTensor`, and each `FixedLenFeature`
+ is mapped to a `Tensor`, of the specified type, shape, and default value.
+
+ `sequence_features` contains `VarLenFeature` and `FixedLenSequenceFeature`
+ objects. Each `VarLenFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenSequenceFeature` is mapped to a `Tensor`, each of the specified type.
+ The shape will be `(B,T,) + df.dense_shape` for `FixedLenSequenceFeature`
+ `df`, where `B` is the batch size, and `T` is the length of the associated
+ `FeatureList` in the `SequenceExample`. For instance,
+ `FixedLenSequenceFeature([])` yields a scalar 2-D `Tensor` of static shape
+ `[None, None]` and dynamic shape `[B, T]`, while
+ `FixedLenSequenceFeature([k])` (for `int k >= 1`) yields a 3-D matrix `Tensor`
+ of static shape `[None, None, k]` and dynamic shape `[B, T, k]`.
+
+ Like the input, the resulting output tensors have a batch dimension. This
+ means that the original per-example shapes of `VarLenFeature`s and
+ `FixedLenSequenceFeature`s can be lost. To handle that situation, this op also
+ provides dicts of shape tensors as part of the output. There is one dict for
+ the context features, and one for the feature_list features. Context features
+ of type `FixedLenFeature`s will not be present, since their shapes are already
+ known by the caller. In situations where the input 'FixedLenFeature`s are of
+ different lengths across examples, the shorter examples will be padded with
+ default datatype values: 0 for numeric types, and the empty string for string
+ types.
+
+ Each `SparseTensor` corresponding to `sequence_features` represents a ragged
+ vector. Its indices are `[time, index]`, where `time` is the `FeatureList`
+ entry and `index` is the value's index in the list of values associated with
+ that time.
+
+ `FixedLenFeature` entries with a `default_value` and `FixedLenSequenceFeature`
+ entries with `allow_missing=True` are optional; otherwise, we will fail if
+ that `Feature` or `FeatureList` is missing from any example in `serialized`.
+
+ `example_name` may contain a descriptive name for the corresponding serialized
+ proto. This may be useful for debugging purposes, but it has no effect on the
+ output. If not `None`, `example_name` must be a scalar.
+
+ Args:
+ serialized: A vector (1-D Tensor) of type string containing binary
+ serialized `SequenceExample` protos.
+ context_features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. These features are associated with a
+ `SequenceExample` as a whole.
+ sequence_features: A `dict` mapping feature keys to
+ `FixedLenSequenceFeature` or `VarLenFeature` values. These features are
+ associated with data within the `FeatureList` section of the
+ `SequenceExample` proto.
+ example_names: A vector (1-D Tensor) of strings (optional), the name of the
+ serialized protos.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple of two `dict`s, each mapping keys to `Tensor`s and `SparseTensor`s.
+ The first dict contains the context key/values.
+ The second dict contains the feature_list key/values.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
+ if not (context_features or sequence_features):
+ raise ValueError("Missing features.")
+ (context_sparse_keys, context_sparse_types, context_dense_keys,
+ context_dense_types,
+ context_dense_defaults, context_dense_shapes) = _features_to_raw_params(
+ context_features, [VarLenFeature, FixedLenFeature])
+ (feature_list_sparse_keys, feature_list_sparse_types, feature_list_dense_keys,
+ feature_list_dense_types, feature_list_dense_defaults,
+ feature_list_dense_shapes) = _features_to_raw_params(
+ sequence_features, [VarLenFeature, FixedLenSequenceFeature])
+ return _parse_sequence_example_raw(
+ serialized, example_names, context_sparse_keys, context_sparse_types,
+ context_dense_keys, context_dense_types, context_dense_defaults,
+ context_dense_shapes, feature_list_sparse_keys, feature_list_sparse_types,
+ feature_list_dense_keys, feature_list_dense_types,
+ feature_list_dense_shapes, feature_list_dense_defaults, name)
+
+
+def _parse_sequence_example_raw(serialized,
+ debug_name=None,
+ context_sparse_keys=None,
+ context_sparse_types=None,
+ context_dense_keys=None,
+ context_dense_types=None,
+ context_dense_defaults=None,
+ context_dense_shapes=None,
+ feature_list_sparse_keys=None,
+ feature_list_sparse_types=None,
+ feature_list_dense_keys=None,
+ feature_list_dense_types=None,
+ feature_list_dense_shapes=None,
+ feature_list_dense_defaults=None,
+ name=None):
+ """Parses a vector of `SequenceExample` protos.
+
+ Args:
+ serialized: A vector (1-D Tensor) of type string, containing binary
+ serialized `SequenceExample` protos.
+ debug_name: A vector (1-D Tensor) of strings (optional), the names of the
+ serialized protos.
+ context_sparse_keys: A list of string keys in the `SequenceExample`'s
+ features. The results for these keys will be returned as `SparseTensor`
+ objects.
+ context_sparse_types: A list of `DTypes`, the same length as `sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ context_dense_keys: A list of string keys in the examples' features. The
+ results for these keys will be returned as `Tensor`s
+ context_dense_types: A list of DTypes, same length as `context_dense_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ context_dense_defaults: A dict mapping string keys to `Tensor`s. The keys of
+ the dict must match the context_dense_keys of the feature.
+ context_dense_shapes: A list of tuples, same length as `context_dense_keys`.
+ The shape of the data for each context_dense feature referenced by
+ `context_dense_keys`. Required for any input tensors identified by
+ `context_dense_keys` whose shapes are anything other than `[]` or `[1]`.
+ feature_list_sparse_keys: A list of string keys in the `SequenceExample`'s
+ feature_lists. The results for these keys will be returned as
+ `SparseTensor` objects.
+ feature_list_sparse_types: A list of `DTypes`, same length as `sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ feature_list_dense_keys: A list of string keys in the `SequenceExample`'s
+ features_lists. The results for these keys will be returned as `Tensor`s.
+ feature_list_dense_types: A list of `DTypes`, same length as
+ `feature_list_dense_keys`. Only `tf.float32` (`FloatList`), `tf.int64`
+ (`Int64List`), and `tf.string` (`BytesList`) are supported.
+ feature_list_dense_shapes: A list of tuples, same length as
+ `feature_list_dense_keys`. The shape of the data for each `FeatureList`
+ feature referenced by `feature_list_dense_keys`.
+ feature_list_dense_defaults: A dict mapping key strings to values. The only
+ currently allowed value is `None`. Any key appearing in this dict with
+ value `None` is allowed to be missing from the `SequenceExample`. If
+ missing, the key is treated as zero-length.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple of three `dict`s, each mapping keys to `Tensor`s and
+ `SparseTensor`s. The first dict contains the context key/values,
+ the second dict contains the feature_list key/values, and the final dict
+ contains the lengths of any dense feature_list features.
+
+ Raises:
+ ValueError: If context_sparse and context_dense key sets intersect,
+ if feature_list_sparse and feature_list_dense key sets intersect,
+ if input lengths do not match up, or if a value in
+ feature_list_dense_defaults is not None.
+ TypeError: if feature_list_dense_defaults is not either None or a dict.
+ """
+ with ops.name_scope(name, "ParseSequenceExample", [serialized]):
+ context_dense_defaults = ({} if context_dense_defaults is None else
+ context_dense_defaults)
+ context_sparse_keys = ([] if context_sparse_keys is None else
+ context_sparse_keys)
+ context_sparse_types = ([] if context_sparse_types is None else
+ context_sparse_types)
+ context_dense_keys = ([]
+ if context_dense_keys is None else context_dense_keys)
+ context_dense_types = ([] if context_dense_types is None else
+ context_dense_types)
+ context_dense_shapes = ([[]] * len(context_dense_keys)
+ if context_dense_shapes is None else
+ context_dense_shapes)
+ feature_list_sparse_keys = ([] if feature_list_sparse_keys is None else
+ feature_list_sparse_keys)
+ feature_list_sparse_types = ([] if feature_list_sparse_types is None else
+ feature_list_sparse_types)
+ feature_list_dense_keys = ([] if feature_list_dense_keys is None else
+ feature_list_dense_keys)
+ feature_list_dense_types = ([] if feature_list_dense_types is None else
+ feature_list_dense_types)
+ feature_list_dense_shapes = ([[]] * len(feature_list_dense_keys)
+ if feature_list_dense_shapes is None else
+ feature_list_dense_shapes)
+ feature_list_dense_defaults = (
+ dict()
+ if feature_list_dense_defaults is None else feature_list_dense_defaults)
+ debug_name = [] if debug_name is None else debug_name
+
+ # Internal
+ feature_list_dense_missing_assumed_empty = []
+
+ num_context_dense = len(context_dense_keys)
+ num_feature_list_dense = len(feature_list_dense_keys)
+ num_context_sparse = len(context_sparse_keys)
+ num_feature_list_sparse = len(feature_list_sparse_keys)
+
+ if len(context_dense_shapes) != num_context_dense:
+ raise ValueError(
+ "len(context_dense_shapes) != len(context_dense_keys): %d vs. %d" %
+ (len(context_dense_shapes), num_context_dense))
+ if len(context_dense_types) != num_context_dense:
+ raise ValueError(
+ "len(context_dense_types) != len(num_context_dense): %d vs. %d" %
+ (len(context_dense_types), num_context_dense))
+ if len(feature_list_dense_shapes) != num_feature_list_dense:
+ raise ValueError(
+ "len(feature_list_dense_shapes) != len(feature_list_dense_keys): "
+ "%d vs. %d" % (len(feature_list_dense_shapes),
+ num_feature_list_dense))
+ if len(feature_list_dense_types) != num_feature_list_dense:
+ raise ValueError(
+ "len(feature_list_dense_types) != len(num_feature_list_dense):"
+ "%d vs. %d" % (len(feature_list_dense_types), num_feature_list_dense))
+ if len(context_sparse_types) != num_context_sparse:
+ raise ValueError(
+ "len(context_sparse_types) != len(context_sparse_keys): %d vs. %d" %
+ (len(context_sparse_types), num_context_sparse))
+ if len(feature_list_sparse_types) != num_feature_list_sparse:
+ raise ValueError(
+ "len(feature_list_sparse_types) != len(feature_list_sparse_keys): "
+ "%d vs. %d" % (len(feature_list_sparse_types),
+ num_feature_list_sparse))
+ if (num_context_dense + num_context_sparse + num_feature_list_dense +
+ num_feature_list_sparse) == 0:
+ raise ValueError(
+ "Must provide at least one context_sparse key, context_dense key, "
+ ", feature_list_sparse key, or feature_list_dense key")
+ if not set(context_dense_keys).isdisjoint(set(context_sparse_keys)):
+ raise ValueError(
+ "context_dense and context_sparse keys must not intersect; "
+ "intersection: %s" % set(context_dense_keys).intersection(
+ set(context_sparse_keys)))
+ if not set(feature_list_dense_keys).isdisjoint(
+ set(feature_list_sparse_keys)):
+ raise ValueError(
+ "feature_list_dense and feature_list_sparse keys must not intersect; "
+ "intersection: %s" % set(feature_list_dense_keys).intersection(
+ set(feature_list_sparse_keys)))
+ if not isinstance(feature_list_dense_defaults, dict):
+ raise TypeError("feature_list_dense_defaults must be a dict")
+ for k, v in feature_list_dense_defaults.items():
+ if v is not None:
+ raise ValueError(
+ "Value feature_list_dense_defaults[%s] must be None" % k)
+ feature_list_dense_missing_assumed_empty.append(k)
+
+ context_dense_defaults_vec = []
+ for i, key in enumerate(context_dense_keys):
+ default_value = context_dense_defaults.get(key)
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=context_dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=context_dense_types[i], name=key_name)
+
+ context_dense_defaults_vec.append(default_value)
+
+ context_dense_shapes = [
+ tensor_shape.as_shape(shape).as_proto()
+ for shape in context_dense_shapes
+ ]
+ feature_list_dense_shapes = [
+ tensor_shape.as_shape(shape).as_proto()
+ for shape in feature_list_dense_shapes
+ ]
+
+ # pylint: disable=protected-access
+ outputs = gen_parsing_ops.parse_sequence_example(
+ serialized=serialized,
+ debug_name=debug_name,
+ Ncontext_sparse=num_context_sparse,
+ Ncontext_dense=num_context_dense,
+ Nfeature_list_sparse=num_feature_list_sparse,
+ Nfeature_list_dense=num_feature_list_dense,
+ context_dense_defaults=context_dense_defaults_vec,
+ context_sparse_keys=context_sparse_keys,
+ context_sparse_types=context_sparse_types,
+ context_dense_keys=context_dense_keys,
+ context_dense_shapes=context_dense_shapes,
+ feature_list_sparse_keys=feature_list_sparse_keys,
+ feature_list_sparse_types=feature_list_sparse_types,
+ feature_list_dense_keys=feature_list_dense_keys,
+ feature_list_dense_types=feature_list_dense_types,
+ feature_list_dense_shapes=feature_list_dense_shapes,
+ feature_list_dense_missing_assumed_empty=(
+ feature_list_dense_missing_assumed_empty),
+ name=name)
+ # pylint: enable=protected-access
+
+ (context_sparse_indices, context_sparse_values, context_sparse_shapes,
+ context_dense_values, feature_list_sparse_indices,
+ feature_list_sparse_values, feature_list_sparse_shapes,
+ feature_list_dense_values, feature_list_dense_lengths) = outputs
+
+ context_sparse_tensors = [
+ sparse_tensor.SparseTensor(ix, val, shape)
+ for (ix, val,
+ shape) in zip(context_sparse_indices, context_sparse_values,
+ context_sparse_shapes)
+ ]
+
+ feature_list_sparse_tensors = [
+ sparse_tensor.SparseTensor(ix, val, shape)
+ for (ix, val, shape
+ ) in zip(feature_list_sparse_indices, feature_list_sparse_values,
+ feature_list_sparse_shapes)
+ ]
+
+ context_output = dict(
+ zip(context_sparse_keys + context_dense_keys,
+ context_sparse_tensors + context_dense_values))
+ feature_list_output = dict(
+ zip(feature_list_sparse_keys + feature_list_dense_keys,
+ feature_list_sparse_tensors + feature_list_dense_values))
+ feature_list_lengths = dict(
+ zip(feature_list_dense_keys, feature_list_dense_lengths))
+
+ return (context_output, feature_list_output, feature_list_lengths)
+
+
+# TODO(sundberg): rewrite this method to call the batch version, which is more
+# efficient especially for large inputs.
@tf_export("parse_single_sequence_example")
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c128a1039a..fa13568596 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -374,6 +374,9 @@ class LayerRNNCell(RNNCell):
class BasicRNNCell(LayerRNNCell):
"""The most basic RNN cell.
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.
+
Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`. It could also be string
@@ -399,6 +402,10 @@ class BasicRNNCell(LayerRNNCell):
**kwargs):
super(BasicRNNCell, self).__init__(
_reuse=reuse, name=name, dtype=dtype, **kwargs)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -457,6 +464,10 @@ class BasicRNNCell(LayerRNNCell):
class GRUCell(LayerRNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
+ `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.
+
Args:
num_units: int, The number of units in the GRU cell.
activation: Nonlinearity to use. Default: `tanh`.
@@ -487,6 +498,10 @@ class GRUCell(LayerRNNCell):
super(GRUCell, self).__init__(
_reuse=reuse, name=name, dtype=dtype, **kwargs)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -610,6 +625,11 @@ class BasicLSTMCell(LayerRNNCell):
For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell`
that follows.
+
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+ `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+ better performance on CPU.
"""
@deprecated(None, "This class is deprecated, please use "
@@ -656,6 +676,10 @@ class BasicLSTMCell(LayerRNNCell):
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -774,6 +798,11 @@ class LSTMCell(LayerRNNCell):
The class uses optional peep-hole connections, optional cell clipping, and
an optional projection layer.
+
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+ `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+ better performance on CPU.
"""
def __init__(self, num_units,
@@ -833,6 +862,10 @@ class LSTMCell(LayerRNNCell):
"%s: The num_unit_shards and proj_unit_shards parameters are "
"deprecated and will be removed in Jan 2017. "
"Use a variable scope with a partitioner instead.", self)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index d1b8be4df7..400a42a3c0 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -1351,7 +1351,11 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
new_shape = array_ops.concat([sp_ids[0].dense_shape[:-1], vocab_size], 0)
result = sparse_tensor.SparseTensor(new_indices, new_values, new_shape)
- return result if already_sorted else sparse_reorder(result)
+ if already_sorted:
+ return result
+ sorted_result = sparse_reorder(result)
+ return sparse_tensor.SparseTensor(
+ sorted_result.indices, sorted_result.values, new_shape)
@tf_export("sparse_retain")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index f53e06fdf9..a43676cd70 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -895,7 +895,7 @@ class _VariableStore(object):
elif not tf_inspect.getargspec(initializer).args:
init_val = initializer
else:
- raise ValueError("You can only pass an initializer function that"
+ raise ValueError("You can only pass an initializer function that "
"expects no arguments to its callable when the "
"shape is not fully defined. The given initializer "
"function expects the following args %s" %
@@ -1558,6 +1558,22 @@ Args:
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
```
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -1591,10 +1607,10 @@ def get_local_variable( # pylint: disable=missing-docstring
partitioner=None,
validate_shape=True,
use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index f7da3f7d64..7a46157739 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -55,33 +55,47 @@ def _make_getter(captured_getter, captured_previous):
@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
- """Indicates when a distributed variable will be synced."""
-
- # Indicates that the synchronization will be determined by the current
- # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
- # `ON_WRITE`).
+ """Indicates when a distributed variable will be synced.
+
+ * `AUTO`: Indicates that the synchronization will be determined by the current
+ `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ `ON_WRITE`).
+ * `NONE`: Indicates that there will only be one copy of the variable, so
+ there is no need to sync.
+ * `ON_WRITE`: Indicates that the variable will be updated across devices
+ every time it is written.
+ * `ON_READ`: Indicates that the variable will be aggregated across devices
+ when it is read (eg. when checkpointing or when evaluating an op that uses
+ the variable).
+ """
AUTO = 0
-
- # Indicates that there will only be one copy of the variable, so there is no
- # need to sync.
NONE = 1
-
- # Indicates that the variable will be aggregated across devices
- # every time it is updated.
ON_WRITE = 2
-
- # Indicates that the variable will be aggregated across devices
- # when it is read (eg. when checkpointing or when evaluating an op that uses
- # the variable).
ON_READ = 3
@tf_export("VariableAggregation")
class VariableAggregation(enum.Enum):
- """Indicates how a distributed variable will be aggregated."""
+ """Indicates how a distributed variable will be aggregated.
+
+ `tf.contrib.distribute.DistributionStrategy` distributes a model by making
+ multiple copies (called "towers") acting data-parallel on different elements
+ of the input batch. When performing some variable-update operation, say
+ `var.assign_add(x)`, in a model, we need to resolve how to combine the
+ different values for `x` computed in the different towers.
+
+ * `NONE`: This is the default, giving an error if you use a
+ variable-update operation with multiple towers.
+ * `SUM`: Add the updates across towers.
+ * `MEAN`: Take the arithmetic mean ("average") of the updates across towers.
+ * `ONLY_FIRST_TOWER`: This is for when every tower is performing the same
+ update, but we only want to perform the update once. Used, e.g., for the
+ global step counter.
+ """
NONE = 0
SUM = 1
MEAN = 2
+ ONLY_FIRST_TOWER = 3
class VariableMetaclass(type):
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index e1c233cdd9..a31861ae40 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -50,11 +50,11 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
-%rename("%s") TFE_Py_TapeSetWatch;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
+%rename("%s") TFE_Py_TapeWatch;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
index 36af091163..90be2cc4f7 100644
--- a/tensorflow/python/tools/api/generator/BUILD
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1")
exports_files(
[
@@ -55,7 +56,7 @@ py_test(
args = [
"--package=tensorflow.python",
"--api_name=tensorflow",
- ] + TENSORFLOW_API_INIT_FILES,
+ ] + TENSORFLOW_API_INIT_FILES + TENSORFLOW_API_INIT_FILES_V1,
main = "doc_srcs_test.py",
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 64f0469482..92446e2f8f 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -87,7 +87,6 @@ TENSORFLOW_API_INIT_FILES = [
"sysconfig/__init__.py",
"test/__init__.py",
"train/__init__.py",
- "train/queue_runner/__init__.py",
"user_ops/__init__.py",
# END GENERATED FILES
]
diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
index ad1988494d..fbec9c6635 100644
--- a/tensorflow/python/tools/api/generator/doc_srcs.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
@@ -62,8 +62,6 @@ _TENSORFLOW_DOC_SOURCES = {
'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
'test': DocSource(docstring_module_name='platform.test'),
'train': DocSource(docstring_module_name='training.training'),
- 'train.queue_runner': DocSource(
- docstring_module_name='training.queue_runner'),
}
_ESTIMATOR_DOC_SOURCES = {
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index bb90d1cd6e..108f2b593c 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -133,14 +133,14 @@ def ensure_graph_is_valid(graph_def):
"""
node_map = {}
for node in graph_def.node:
- if node.name not in node_map.keys():
+ if node.name not in node_map:
node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
for node in graph_def.node:
for input_name in node.input:
input_node_name = node_name_from_input(input_name)
- if input_node_name not in node_map.keys():
+ if input_node_name not in node_map:
raise ValueError("Input for ", node.name, " not found: ", input_name)
@@ -225,7 +225,7 @@ def fold_batch_norms(input_graph_def):
"""
input_node_map = {}
for node in input_graph_def.node:
- if node.name not in input_node_map.keys():
+ if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
@@ -390,7 +390,7 @@ def fuse_resize_and_conv(input_graph_def, output_node_names):
input_node_map = {}
for node in input_graph_def.node:
- if node.name not in input_node_map.keys():
+ if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 38fed5335e..6716c79f87 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -40,8 +40,8 @@ from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
@@ -140,7 +140,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
meta_graph_def, signature_def_key)
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -166,7 +166,7 @@ def _print_tensor_info(tensor_info, indent=0):
tensor_info: TensorInfo object to be printed.
indent: How far (in increments of 2 spaces) to indent each line output
"""
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -270,7 +270,7 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, tf_debug=False):
+ overwrite_flag, worker=None, tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -288,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
it will be created.
overwrite_flag: A boolean flag to allow overwrite output file if file with
the same name exists.
+ worker: If provided, the session will be run on the worker. Valid worker
+ specification is a bns or gRPC path.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -308,7 +310,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
# Check if input tensor keys are valid.
for input_key_name in input_tensor_key_feed_dict.keys():
- if input_key_name not in inputs_tensor_info.keys():
+ if input_key_name not in inputs_tensor_info:
raise ValueError(
'"%s" is not a valid input key. Please choose from %s, or use '
'--show option.' %
@@ -328,7 +330,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
for tensor_key in output_tensor_keys_sorted
]
- with session.Session(graph=ops_lib.Graph()) as sess:
+ with session.Session(worker, graph=ops_lib.Graph()) as sess:
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +634,8 @@ def run(args):
args.inputs, args.input_exprs, args.input_examples)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
- args.overwrite, tf_debug=args.tf_debug)
+ args.overwrite, worker=args.worker,
+ tf_debug=args.tf_debug)
def scan(args):
@@ -769,6 +772,12 @@ def create_parser():
help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
'intermediate Tensors and runtime GraphDefs while running the '
'SavedModel.')
+ parser_run.add_argument(
+ '--worker',
+ type=str,
+ default=None,
+ help='if specified, a Session will be run on the worker. '
+ 'Valid worker specification is a bns or gRPC path.')
parser_run.set_defaults(func=run)
# scan command
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 45d217e8b1..13dddd37ac 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -685,6 +685,11 @@ def _serialize_object_graph(root_checkpointable, saveables_cache):
saveables_cache=saveables_cache)
+def named_saveables(root_checkpointable):
+ """Gather list of all SaveableObjects in the Checkpointable object."""
+ return _serialize_object_graph(root_checkpointable, None)[0]
+
+
def list_objects(root_checkpointable):
"""Traverse the object graph and list all accessible objects.
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 1ac7c39872..21ca1735e0 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -371,7 +372,7 @@ class DistributionStrategy(object):
use its API, including `merge_call()` to get back to cross-tower
context), once for each tower. May use values with locality T or
M, and any variable.
- * `d.reduce(m, t)`: in cross-tower context, accepts t with locality T
+ * `d.reduce(m, t, t)`: in cross-tower context, accepts t with locality T
and produces a value with locality M.
* `d.reduce(m, t, v)`: in cross-tower context, accepts t with
locality T and produces a value with locality V(`v`).
@@ -404,10 +405,11 @@ class DistributionStrategy(object):
Another thing you might want to do in the middle of your tower function
is an all-reduce of some intermediate value, using `d.reduce()` or
- `d.batch_reduce()` without supplying a variable as the destination.
+ `d.batch_reduce()`. You simply provide the same tensor as the input and
+ destination.
Layers should expect to be called in a tower context, and can use
- the `get_tower_context()` function to get a `TowerContext` object. The
+ the `get_tower_context()` function to get a `TowerContext` object. The
`TowerContext` object has a `merge_call()` method for entering
cross-tower context where you can use `reduce()` (or
`batch_reduce()`) and then optionally `update()` to update state.
@@ -718,18 +720,18 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, aggregation, value, destinations=None):
+ def reduce(self, aggregation, value, destinations):
"""Combine (via e.g. sum or mean) values across towers.
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value: A per-device value with one value per tower.
- destinations: An optional mirrored variable, a device string,
- list of device strings. The return value will be copied to all
- destination devices (or all the devices where the mirrored
- variable resides). If `None` or unspecified, the destinations
- will match the devices `value` resides on.
+ destinations: A mirrored variable, a per-device tensor, a device string,
+ or list of device strings. The return value will be copied to all
+ destination devices (or all the devices where the `destinations` value
+ resides). To perform an all-reduction, pass `value` to `destinations`.
Returns:
A value mirrored to `destinations`.
@@ -740,7 +742,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._reduce(aggregation, value, destinations)
@@ -752,7 +755,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -763,7 +767,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -1072,10 +1077,15 @@ class TowerContext(object):
require_tower_context(self)
return device_util.current()
- # TODO(josh11b): Implement `start_all_reduce(method, t)` that returns
- # a function returning the result of reducing `t` across all
- # towers. Most likely can be implemented in terms of `merge_call()`
- # and `batch_reduce()`.
+ # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
+ # all-reduce. It would return a function returning the result of reducing `t`
+ # across all towers. The caller would wait to call this function until they
+ # needed the reduce result, allowing an efficient implementation:
+ # * With eager execution, the reduction could be performed asynchronously
+ # in the background, not blocking until the result was needed.
+ # * When constructing a graph, it could batch up all reduction requests up
+ # to that point that the first result is needed. Most likely this can be
+ # implemented in terms of `merge_call()` and `batch_reduce()`.
# ------------------------------------------------------------------------------
@@ -1168,9 +1178,14 @@ class _DefaultDistributionStrategy(DistributionStrategy):
# ------------------------------------------------------------------------------
-# Common operations
+# Deprecated, use v.assign_add(amount) instead. Internal API, so expect
+# it to be deleted soon.
+@deprecation.deprecated(None,
+ "Use v.assign_add(amount) instead. You may need to set "
+ "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER "
+ "when creating the variable.")
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 775bdb3f60..76ca5b45c9 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -117,8 +117,7 @@ class FtrlOptimizerTest(test.TestCase):
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
- self.assertAllCloseAccordingToType(
- [[0, 1]], var0.eval(), atol=0.01)
+ self.assertAllCloseAccordingToType([[0, 1]], var0.eval(), atol=0.01)
def testFtrlWithL1(self):
for dtype in [dtypes.half, dtypes.float32]:
@@ -212,24 +211,96 @@ class FtrlOptimizerTest(test.TestCase):
v0_val, v1_val = sess.run([var0, var1])
self.assertAllCloseAccordingToType(
- np.array([-0.22078767, -0.41378114]), v0_val)
+ np.array([-0.22578995, -0.44345796]), v0_val)
self.assertAllCloseAccordingToType(
- np.array([-0.02919818, -0.07343706]), v1_val)
+ np.array([-0.14378493, -0.13229476]), v1_val)
+
+ def testFtrlWithL1_L2_L2ShrinkageSparse(self):
+ """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
+ self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
+ self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((v0_val**2 < v1_val**2).all())
+ accum0 = list(sess.run(opt0._slots)["accum"].values())[0]
+ accum1 = list(sess.run(opt1._slots)["accum"].values())[0]
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
def applyOptimizer(self, opt, dtype, steps=5, is_sparse=False):
if is_sparse:
var0 = variables.Variable([[0.0], [0.0]], dtype=dtype)
var1 = variables.Variable([[0.0], [0.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
- constant_op.constant(
- [0.1], shape=[1, 1], dtype=dtype),
- constant_op.constant([0]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
grads1 = ops.IndexedSlices(
- constant_op.constant(
- [0.02], shape=[1, 1], dtype=dtype),
- constant_op.constant([1]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
else:
var0 = variables.Variable([0.0, 0.0], dtype=dtype)
var1 = variables.Variable([0.0, 0.0], dtype=dtype)
@@ -277,8 +348,7 @@ class FtrlOptimizerTest(test.TestCase):
with self.test_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1), dtype)
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
self.assertAllCloseAccordingToType(val0, val2)
self.assertAllCloseAccordingToType(val1, val3)
@@ -299,8 +369,7 @@ class FtrlOptimizerTest(test.TestCase):
with self.test_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1),
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
dtype,
is_sparse=True)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index c077630de2..0e0125a956 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -800,7 +800,8 @@ class _MonitoredSession(object):
self.tf_sess = self._session_creator.create_session()
# We don't want coordinator to suppress any exception.
self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
- queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
+ if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+ queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
# Inform the hooks that a new session has been created.
for hook in self._hooks:
hook.after_create_session(self.tf_sess, self.coord)
@@ -1363,3 +1364,6 @@ class _HookedSession(_WrappedSession):
options.debug_options.debug_tensor_watch_opts.extend(
incoming_options.debug_options.debug_tensor_watch_opts)
+ options.debug_options.reset_disk_byte_usage = (
+ options.debug_options.reset_disk_byte_usage or
+ incoming_options.debug_options.reset_disk_byte_usage)
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py
index d38c5499c7..ac9d4c850d 100644
--- a/tensorflow/python/training/queue_runner_impl.py
+++ b/tensorflow/python/training/queue_runner_impl.py
@@ -27,10 +27,14 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+_DEPRECATION_INSTRUCTION = (
+ "To construct input pipelines, use the `tf.data` module.")
-@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner")
+
+@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
"""Holds a list of enqueue operations for a queue, each to be run in a thread.
@@ -53,6 +57,7 @@ class QueueRunner(object):
@end_compatibility
"""
+ @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def __init__(self, queue=None, enqueue_ops=None, close_op=None,
cancel_op=None, queue_closed_exception_types=None,
queue_runner_def=None, import_scope=None):
@@ -386,7 +391,8 @@ class QueueRunner(object):
import_scope=import_scope)
-@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner")
+@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
+@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Adds a `QueueRunner` to a collection in the graph.
@@ -405,8 +411,9 @@ def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
ops.add_to_collection(collection, qr)
-@tf_export("train.queue_runner.start_queue_runners",
- "train.start_queue_runners")
+@tf_export(v1=["train.queue_runner.start_queue_runners",
+ "train.start_queue_runners"])
+@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Starts all queue runners collected in the graph.
@@ -458,6 +465,13 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
raise TypeError("sess must be a `tf.Session` object. "
"Given class: {}".format(sess.__class__))
+ queue_runners = ops.get_collection(collection)
+ if not queue_runners:
+ logging.warning(
+ "`tf.train.start_queue_runners()` was called when no queue runners "
+ "were defined. You can safely remove the call to this deprecated "
+ "function.")
+
with sess.graph.as_default():
threads = []
for qr in ops.get_collection(collection):
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 2ff3eeb153..d998d6af81 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -129,6 +129,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
# Create in proper graph and base name_scope.
@@ -139,6 +140,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 55408ab9ab..207f22c931 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3894,7 +3894,7 @@ bool CudnnSupport::DoDepthConcatenate(
for (size_t i = 0; i < input_data.size(); ++i) {
const auto& dimensions = input_dimensions[i];
tmp.resize(dimensions.ElementCount());
- stream->ThenMemcpyD2H<float>(*input_data[i], &tmp);
+ stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
port::Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 3562a5192d..adac895a17 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -329,11 +329,16 @@ def tf_binary_additional_srcs():
],
)
+def _linux_kernel_dso_name(kernel_build_target):
+ """Given a build target, construct the dso name for linux."""
+ parts = kernel_build_target.split(":")
+ return "%s:libtfkernel_%s.so" % (parts[0], parts[1])
+
# Helper functions to add kernel dependencies to tf binaries when using dynamic
# kernel linking.
def tf_binary_dynamic_kernel_dsos(kernels):
return if_dynamic_kernels(
- extra_deps = ["libtfkernel_%s.so" % clean_dep(k) for k in kernels],
+ extra_deps = [_linux_kernel_dso_name(k) for k in kernels],
otherwise = [],
)
@@ -796,6 +801,7 @@ def tf_cuda_cc_test(
extra_copts = [],
linkstatic = 0,
args = [],
+ kernels = [],
linkopts = []):
tf_cc_test(
name = name,
@@ -808,6 +814,7 @@ def tf_cuda_cc_test(
linkstatic = linkstatic,
linkopts = linkopts,
args = args,
+ kernels = kernels,
)
tf_cc_test(
name = name,
@@ -829,6 +836,7 @@ def tf_cuda_cc_test(
extra_copts = extra_copts,
linkopts = linkopts,
args = args,
+ kernels = kernels,
)
register_extension_info(
@@ -884,6 +892,7 @@ def tf_cc_tests(
size = "medium",
args = None,
linkopts = [],
+ kernels = [],
nocopts = None):
for src in srcs:
tf_cc_test(
@@ -896,6 +905,7 @@ def tf_cc_tests(
args = args,
linkopts = linkopts,
nocopts = nocopts,
+ kernels = kernels,
)
def tf_cc_test_mkl(
@@ -943,8 +953,9 @@ def tf_cc_tests_gpu(
linkstatic = 0,
tags = [],
size = "medium",
+ kernels = [],
args = None):
- tf_cc_tests(srcs, deps, linkstatic, tags = tags, size = size, args = args)
+ tf_cc_tests(srcs, deps, linkstatic, tags = tags, size = size, kernels = kernels, args = args)
def tf_cuda_cc_tests(
srcs,
@@ -954,6 +965,7 @@ def tf_cuda_cc_tests(
size = "medium",
linkstatic = 0,
args = None,
+ kernels = [],
linkopts = []):
for src in srcs:
tf_cuda_cc_test(
@@ -964,6 +976,7 @@ def tf_cuda_cc_tests(
size = size,
linkstatic = linkstatic,
args = args,
+ kernels = kernels,
linkopts = linkopts,
)
@@ -1352,12 +1365,13 @@ def transitive_hdrs(name, deps = [], **kwargs):
# Create a header only library that includes all the headers exported by
# the libraries in deps.
-def cc_header_only_library(name, deps = [], includes = [], **kwargs):
+def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kwargs):
_transitive_hdrs(name = name + "_gather", deps = deps)
native.cc_library(
name = name,
hdrs = [":" + name + "_gather"],
includes = includes,
+ deps = extra_deps,
**kwargs
)
@@ -1654,17 +1668,17 @@ def tf_py_wrap_cc(
# Note that this only works on Windows. See the definition of
# //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons.
# 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test.
-def py_test(deps = [], data = [], **kwargs):
+def py_test(deps = [], data = [], kernels = [], **kwargs):
native.py_test(
# TODO(jlebar): Ideally we'd use tcmalloc here.,
deps = select({
"//conditions:default": deps,
clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }),
+ }) + tf_binary_dynamic_kernel_deps(kernels),
data = data + select({
"//conditions:default": [],
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
- }),
+ }) + tf_binary_dynamic_kernel_dsos(kernels),
**kwargs
)
@@ -1683,6 +1697,7 @@ def tf_py_test(
tags = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
flaky = 0,
xla_enabled = False,
grpc_enabled = False):
@@ -1699,6 +1714,7 @@ def tf_py_test(
tags = tags,
visibility = [clean_dep("//tensorflow:internal")],
shard_count = shard_count,
+ kernels = kernels,
data = data,
deps = [
clean_dep("//tensorflow/python:extra_py_tests_deps"),
@@ -1722,6 +1738,7 @@ def cuda_py_test(
args = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
tags = [],
flaky = 0,
xla_enabled = False,
@@ -1737,6 +1754,7 @@ def cuda_py_test(
tags = test_tags,
shard_count = shard_count,
additional_deps = additional_deps,
+ kernels = kernels,
flaky = flaky,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
@@ -1756,6 +1774,7 @@ def sycl_py_test(
args = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
tags = [],
flaky = 0,
xla_enabled = False,
@@ -1771,6 +1790,7 @@ def sycl_py_test(
tags = test_tags,
shard_count = shard_count,
additional_deps = additional_deps,
+ kernels = kernels,
flaky = flaky,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
@@ -1786,6 +1806,7 @@ def py_tests(
srcs,
size = "medium",
additional_deps = [],
+ kernels = [],
data = [],
tags = [],
shard_count = 1,
@@ -1805,6 +1826,7 @@ def py_tests(
shard_count = shard_count,
data = data,
additional_deps = additional_deps,
+ kernels = kernels,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
)
@@ -1814,6 +1836,7 @@ def cuda_py_tests(
srcs,
size = "medium",
additional_deps = [],
+ kernels = [],
data = [],
shard_count = 1,
tags = [],
@@ -1830,6 +1853,7 @@ def cuda_py_tests(
tags = test_tags,
shard_count = shard_count,
prefix = prefix,
+ kernels = kernels,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
index 36b534af36..66a20547eb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
@@ -10,6 +10,10 @@ tf_class {
mtype: "<enum \'VariableAggregation\'>"
}
member {
+ name: "ONLY_FIRST_TOWER"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
name: "SUM"
mtype: "<enum \'VariableAggregation\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt
new file mode 100644
index 0000000000..483d1f8ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_normal_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt
new file mode 100644
index 0000000000..bb8540d0fd
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_uniform_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..4a81e52df9
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..815dc81dff
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
index bc0426f2f1..d499c67d89 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
@@ -5,6 +5,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -45,14 +53,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
index 3a36c168aa..8938cf217b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
@@ -25,6 +25,10 @@ tf_module {
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
index 23cd02c0b0..26784ce55d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
index d98628f422..4110bda5f6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomUniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
index 86d48257c1..0451d0d73a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.TruncatedNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..ef0815972d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..439b5ada9b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
index 7485772784..8d0b5c242b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
index 8645e54302..1540c2915b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
@@ -45,6 +45,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -89,14 +97,6 @@ tf_module {
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
index a6df1e87a3..bac8211a10 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
index 37a0fa0d55..ab0d74d071 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
index f97e93f0b7..358cca2b9c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.truncated_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
index 58186b1383..e6c731361a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 821ca7b140..dd9f7c49e0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -365,6 +365,14 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "glorot_normal_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "graph_util"
mtype: "<type \'module\'>"
}
@@ -1182,7 +1190,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_seed"
@@ -1217,14 +1225,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
index 36b534af36..66a20547eb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
@@ -10,6 +10,10 @@ tf_class {
mtype: "<enum \'VariableAggregation\'>"
}
member {
+ name: "ONLY_FIRST_TOWER"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
name: "SUM"
mtype: "<enum \'VariableAggregation\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt
new file mode 100644
index 0000000000..483d1f8ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_normal_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt
new file mode 100644
index 0000000000..bb8540d0fd
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_uniform_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..4a81e52df9
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..815dc81dff
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
index bc0426f2f1..d499c67d89 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
@@ -5,6 +5,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -45,14 +53,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
index 3a36c168aa..8938cf217b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
@@ -25,6 +25,10 @@ tf_module {
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
index 23cd02c0b0..26784ce55d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
index d98628f422..4110bda5f6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomUniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
index 86d48257c1..0451d0d73a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.TruncatedNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..ef0815972d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..439b5ada9b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
index 7485772784..8d0b5c242b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
index 8645e54302..1540c2915b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
@@ -45,6 +45,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -89,14 +97,6 @@ tf_module {
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
index a6df1e87a3..bac8211a10 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
index 37a0fa0d55..ab0d74d071 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
index f97e93f0b7..358cca2b9c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.truncated_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
index 58186b1383..e6c731361a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 519cf66aa4..7d45ea22c8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -365,6 +365,14 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "glorot_normal_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "graph_util"
mtype: "<type \'module\'>"
}
@@ -1158,7 +1166,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_seed"
@@ -1193,14 +1201,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt
deleted file mode 100644
index d84d0058ee..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt
+++ /dev/null
@@ -1,49 +0,0 @@
-path: "tensorflow.train.QueueRunner"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.queue_runner_impl.QueueRunner\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cancel_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "close_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "enqueue_ops"
- mtype: "<type \'property\'>"
- }
- member {
- name: "exceptions_raised"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue_closed_exception_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "create_threads"
- argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index 9f35395284..c35e254843 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -145,10 +145,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "QueueRunner"
- mtype: "<type \'type\'>"
- }
- member {
name: "RMSPropOptimizer"
mtype: "<type \'type\'>"
}
@@ -236,10 +232,6 @@ tf_module {
name: "WorkerSessionCreator"
mtype: "<type \'type\'>"
}
- member {
- name: "queue_runner"
- mtype: "<type \'module\'>"
- }
member_method {
name: "MonitoredTrainingSession"
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
@@ -249,10 +241,6 @@ tf_module {
argspec: "args=[\'filepattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "add_queue_runner"
- argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], "
- }
- member_method {
name: "assert_global_step"
argspec: "args=[\'global_step_tensor\'], varargs=None, keywords=None, defaults=None"
}
@@ -433,10 +421,6 @@ tf_module {
argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
}
member_method {
- name: "start_queue_runners"
- argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], "
- }
- member_method {
name: "string_input_producer"
argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt
deleted file mode 100644
index 23d402de30..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt
+++ /dev/null
@@ -1,49 +0,0 @@
-path: "tensorflow.train.queue_runner.QueueRunner"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.queue_runner_impl.QueueRunner\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cancel_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "close_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "enqueue_ops"
- mtype: "<type \'property\'>"
- }
- member {
- name: "exceptions_raised"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue_closed_exception_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "create_threads"
- argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt
deleted file mode 100644
index 6e2d043049..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.train.queue_runner"
-tf_module {
- member {
- name: "QueueRunner"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "add_queue_runner"
- argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], "
- }
- member_method {
- name: "start_queue_runners"
- argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], "
- }
-}
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu
index 383f9545c9..f05c7a4809 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu
@@ -30,4 +30,3 @@ RUN mkdir /usr/local/cuda-9.0/lib && \
# Configure the build for our CUDA configuration.
ENV TF_NEED_CUDA 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
index 24ff4765a6..b656205836 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
@@ -19,8 +19,8 @@ RUN /install/install_golang.sh
# Install clang from pre-built package
RUN cd /tmp && \
- wget https://storage.googleapis.com/clang-builds-stable/clang-ubuntu16_04/clang_r323528.tar.gz && \
- echo "26752d9f5785df07193fac8316ba5d5ba3bec36d970c29a1577360848818ac74 clang_r323528.tar.gz" | sha256sum -c && \
+ wget https://storage.googleapis.com/clang-builds-stable/clang-ubuntu16_04/clang_r337145.tar.gz && \
+ echo "ab98c63eb09c04112cc992bc95ebc0dcea8c5e9d0760438789be2896cdc69ff8 clang_r337145.tar.gz" | sha256sum -c && \
tar -C /usr/local -xf clang_r323528.tar.gz && \
- rm clang_r323528.tar.gz
+ rm clang_r337145.tar.gz
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 993894d658..1d7d9df72f 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -377,6 +377,10 @@ else
if [[ ${IS_MAC} == "1" ]]; then
EXTRA_ARGS="${EXTRA_ARGS},-nomac"
fi
+ EXTRA_ARGS="${EXTRA_ARGS} --build_tag_filters=-no_oss,-oss_serial,-benchmark-test"
+ if [[ ${IS_MAC} == "1" ]]; then
+ EXTRA_ARGS="${EXTRA_ARGS},-nomac"
+ fi
fi
# For any "tool" dependencies in genrules, Bazel will build them for host
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index b6fa6f6dab..e487779e7a 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -13,8 +13,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-dev-9-0 \
curl \
git \
- libcudnn7=7.1.4.18-1+cuda9.0 \
- libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
@@ -35,6 +35,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
# Link NCCL libray and header where the build script expects them.
RUN mkdir /usr/local/cuda-9.0/lib && \
ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
@@ -100,6 +106,7 @@ RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.
ENV CI_BUILD_PYTHON python
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=9.0
ENV TF_CUDNN_VERSION=7
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
deleted file mode 100644
index eb139ec5f8..0000000000
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ /dev/null
@@ -1,117 +0,0 @@
-FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
-
-LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
-
-# It is possible to override these for releases.
-ARG TF_BRANCH=master
-ARG BAZEL_VERSION=0.15.0
-ARG TF_AVAILABLE_CPUS=32
-
-RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential \
- curl \
- git \
- golang \
- libcurl3-dev \
- libfreetype6-dev \
- libpng12-dev \
- libzmq3-dev \
- pkg-config \
- python-dev \
- python-pip \
- rsync \
- software-properties-common \
- unzip \
- zip \
- zlib1g-dev \
- openjdk-8-jdk \
- openjdk-8-jre-headless \
- wget \
- && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists/*
-
-RUN pip --no-cache-dir install --upgrade \
- pip setuptools
-
-RUN pip --no-cache-dir install \
- ipykernel \
- jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
- matplotlib \
- numpy \
- scipy \
- sklearn \
- pandas \
- wheel \
- && \
- python -m ipykernel.kernelspec
-
-# Set up our notebook config.
-COPY jupyter_notebook_config.py /root/.jupyter/
-
-# Jupyter has issues with being run directly:
-# https://github.com/ipython/ipython/issues/7062
-# We just add a little wrapper script.
-COPY run_jupyter.sh /
-
-# Set up Bazel.
-
-# Running bazel inside a `docker build` command causes trouble, cf:
-# https://github.com/bazelbuild/bazel/issues/134
-# The easiest solution is to set up a bazelrc file forcing --batch.
-RUN echo "startup --batch" >>/etc/bazel.bazelrc
-# Similarly, we need to workaround sandboxing issues:
-# https://github.com/bazelbuild/bazel/issues/418
-RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
- >>/etc/bazel.bazelrc
-WORKDIR /
-RUN mkdir /bazel && \
- cd /bazel && \
- wget --quiet https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- wget --quiet https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
- chmod +x bazel-*.sh && \
- ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
-
-# Download and build TensorFlow.
-WORKDIR /
-RUN git clone https://github.com/tensorflow/tensorflow.git && \
- cd tensorflow && \
- git checkout ${TF_BRANCH}
-WORKDIR /tensorflow
-
-# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON=python \
- LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH} \
- CUDNN_INSTALL_PATH=/usr/lib/x86_64-linux-gnu \
- PYTHON_BIN_PATH=/usr/bin/python \
- PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
- TF_NEED_CUDA=1 \
- TF_CUDA_VERSION=9.0 \
- TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1,7.0 \
- TF_CUDNN_VERSION=7
-RUN ./configure
-
-# Build and Install TensorFlow.
-RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
- LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
- bazel build -c opt \
- --config=cuda \
- --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
- --jobs=${TF_AVAILABLE_CPUS} \
- tensorflow/tools/pip_package:build_pip_package && \
- mkdir /pip_pkg && \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package /pip_pkg && \
- pip --no-cache-dir install --upgrade /pip_pkg/tensorflow-*.whl && \
- rm -rf /pip_pkg && \
- rm -rf /root/.cache
-# Clean up pip wheel and Bazel cache when done.
-
-WORKDIR /root
-
-# TensorBoard
-EXPOSE 6006
-# IPython
-EXPOSE 8888
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index c68082842d..781bf9e851 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusolver-9-0 \
cuda-cusparse-9-0 \
curl \
- libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -28,6 +28,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0
+
RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
index 0f5fedf2fe..68c0e2f2bd 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
@@ -50,8 +50,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-dev-9-0 \
curl \
git \
- libcudnn7=7.1.4.18-1+cuda9.0 \
- libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
@@ -71,6 +71,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
# Link NCCL libray and header where the build script expects them.
RUN mkdir /usr/local/cuda-9.0/lib && \
ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
index a6e280082e..77be0dd287 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
@@ -48,8 +48,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-dev-9-0 \
curl \
git \
- libcudnn7=7.1.4.18-1+cuda9.0 \
- libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
@@ -69,6 +69,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
# Link NCCL libray and header where the build script expects them.
RUN mkdir /usr/local/cuda-9.0/lib && \
ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
index f1799113b1..5ff1fa917a 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
@@ -48,7 +48,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-curand-9-0 \
cuda-cusolver-9-0 \
cuda-cusparse-9-0 \
- libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -61,6 +61,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0
+
ARG USE_PYTHON_3_NOT_2=True
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
index 690eb68b22..3df810b5fe 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
@@ -46,7 +46,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-curand-9-0 \
cuda-cusolver-9-0 \
cuda-cusparse-9-0 \
- libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -59,6 +59,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0
+
ARG USE_PYTHON_3_NOT_2=True
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
index f31b695e77..45159f711f 100644
--- a/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
@@ -12,8 +12,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-dev-9-0 \
curl \
git \
- libcudnn7=7.1.4.18-1+cuda9.0 \
- libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
@@ -33,6 +33,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
# Link NCCL libray and header where the build script expects them.
RUN mkdir /usr/local/cuda-9.0/lib && \
ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
index 13d865b9d4..1064390af3 100644
--- a/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
@@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-curand-9-0 \
cuda-cusolver-9-0 \
cuda-cusparse-9-0 \
- libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -21,3 +21,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
+
+RUN apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install libnvinfer4=4.1.2-1+cuda9.0
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 090cf48a07..483921fc2f 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -418,8 +418,8 @@ class _GenerateGuideIndex(py_guide_parser.PyGuideParser):
self.section_tag = tag
def process_line(self, _, line):
- """Index @{symbol} references as in the current file & section."""
- for match in parser.SYMBOL_REFERENCE_RE.finditer(line):
+ """Index the file and section of each `symbol` reference."""
+ for match in parser.AUTO_REFERENCE_RE.finditer(line):
val = self.index.get(match.group(1), [])
val.append(
_GuideRef(self.base_name, self.title, self.section_title,
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 8e444a15cf..997afc6ac7 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -1698,7 +1698,7 @@ class _Metadata(object):
version: The source version.
"""
- def __init__(self, name, version='stable'):
+ def __init__(self, name, version='Stable'):
"""Creates a Metadata builder.
Args:
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index 63d4fef91c..aecf753a58 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -93,6 +93,15 @@ def _build_class_page(page_info):
parts.append('\n\n')
+ # Sort the methods list, but make sure constructors come first.
+ constructor_names = ['__init__', '__new__']
+ constructors = sorted(
+ method for method in page_info.methods
+ if method.short_name in constructor_names)
+ other_methods = sorted(
+ method for method in page_info.methods
+ if method.short_name not in constructor_names)
+
if len(page_info.aliases) > 1:
parts.append('### Aliases:\n\n')
parts.extend('* Class `%s`\n' % name for name in page_info.aliases)
@@ -109,6 +118,11 @@ def _build_class_page(page_info):
parts.append('\n\n')
+ if constructors:
+ for method_info in constructors:
+ parts.append(_build_method_section(method_info, heading_level=2))
+ parts.append('\n\n')
+
if page_info.classes:
parts.append('## Child Classes\n')
@@ -134,28 +148,11 @@ def _build_class_page(page_info):
parts.append('\n\n')
- if page_info.methods:
+ if other_methods:
parts.append('## Methods\n\n')
- # Sort the methods list, but make sure constructors come first.
- constructors = ['__init__', '__new__']
- inits = [method for method in page_info.methods
- if method.short_name in constructors]
- others = [method for method in page_info.methods
- if method.short_name not in constructors]
-
- for method_info in sorted(inits) + sorted(others):
- h3 = ('<h3 id="{short_name}">'
- '<code>{short_name}</code>'
- '</h3>\n\n')
- parts.append(h3.format(**method_info._asdict()))
-
- if method_info.signature is not None:
- parts.append(_build_signature(method_info, use_full_name=False))
-
- parts.append(method_info.doc.docstring)
- parts.append(_build_function_details(method_info.doc.function_details))
- parts.append(_build_compatibility(method_info.doc.compatibility))
- parts.append('\n\n')
+
+ for method_info in other_methods:
+ parts.append(_build_method_section(method_info))
parts.append('\n\n')
if page_info.other_members:
@@ -172,6 +169,33 @@ def _build_class_page(page_info):
return ''.join(parts)
+def _build_method_section(method_info, heading_level=3):
+ """Generates a markdown section for a method.
+
+ Args:
+ method_info: A `MethodInfo` object.
+ heading_level: An Int, which HTML heading level to use.
+
+ Returns:
+ A markdown string.
+ """
+ parts = []
+ heading = ('<h{heading_level} id="{short_name}">'
+ '<code>{short_name}</code>'
+ '</h{heading_level}>\n\n')
+ parts.append(heading.format(heading_level=heading_level,
+ **method_info._asdict()))
+
+ if method_info.signature is not None:
+ parts.append(_build_signature(method_info, use_full_name=False))
+
+ parts.append(method_info.doc.docstring)
+ parts.append(_build_function_details(method_info.doc.function_details))
+ parts.append(_build_compatibility(method_info.doc.compatibility))
+ parts.append('\n\n')
+ return ''.join(parts)
+
+
def _build_module_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)]
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9f9340254c..997725d865 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -60,31 +60,31 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
mkl_repository(
name = "mkl_linux",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
],
- sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725",
- strip_prefix = "mklml_lnx_2018.0.3.20180406",
+ sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
+ strip_prefix = "mklml_lnx_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_windows",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
],
- sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694",
- strip_prefix = "mklml_win_2018.0.3.20180406",
+ sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
+ strip_prefix = "mklml_win_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_darwin",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
],
- sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b",
- strip_prefix = "mklml_mac_2018.0.3.20180406",
+ sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
+ strip_prefix = "mklml_mac_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
@@ -95,22 +95,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
- "https://github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
+ "https://github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
],
- sha256 = "da1f27f92453a65331197dd8e4992e810fb7b1c4e0b902a1da5611592df2b633",
- strip_prefix = "mkl-dnn-0c1cf54b63732e5a723c5670f66f6dfb19b64d20",
+ sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
+ strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
)
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/f0f15c2778b0e4959244dd25e63f445a455870f5.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/f0f15c2778b0e4959244dd25e63f445a455870f5.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
],
- sha256 = "4ee36dacb75846eaa209ce8060bb269a42b7b3903612ca6d9e86a692659fe8c1",
- strip_prefix = "abseil-cpp-f0f15c2778b0e4959244dd25e63f445a455870f5",
+ sha256 = "cb4e11259742954f88802be6f33c1007c16502d90d68e8898b5e5084264ca8a9",
+ strip_prefix = "abseil-cpp-c075ad321696fa5072e097f0a51e4fe76a6fe13e",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
@@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/deac5c28e00179be248aaf03abd329a848e8fac8.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/deac5c28e00179be248aaf03abd329a848e8fac8.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
],
- sha256 = "bb55a553facff0408574a7bbd0d93c7371dbf527c7020fc6f4b9adeb0d83f780",
- strip_prefix = "llvm-deac5c28e00179be248aaf03abd329a848e8fac8",
+ sha256 = "b8f4ffbcaeea345e2245fd7028c7e960d71c2a2007c20bbfc5d79ecc86992a5e",
+ strip_prefix = "llvm-67bd0d9a0f5597f57f272061fd70f24dffb3d223",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
@@ -767,6 +767,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
+
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
@@ -778,6 +779,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant_protobuf",
+ sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ ],
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
+
+ tf_http_archive(
name = "tflite_conv_actions_frozen",
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index 5ef47cdd0d..e782739661 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -39,15 +39,15 @@ def download_clang(repo_ctx, out_folder):
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
- CLANG_REVISION = "338452"
+ CLANG_REVISION = "340427"
CLANG_SUB_REVISION = 1
package_version = "%s-%s" % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
- "Linux_x64": "213ba23a0a9855ede5041f66661caa9c5c59a573ec60b82a31839f9a97f397bf",
- "Mac": "4267774201f8cb50c25e081375e87038d58db80064a20a0d9d7fe57ea4357ece",
- "Win": "a8a5d5b25443c099e2c20d1a0cdce2f1d17e2dba84de66a6dc6a239ce3e78c34",
+ "Linux_x64": "8a8f21fb624fc7be7e91e439a13114847185375bb932db51ba590174ecaf764b",
+ "Mac": "ba894536b7c8d37103a5ddba784f268d55e65bb2ea1200a2cf9f2ef1590eaacd",
+ "Win": "c3f5bd977266dfd011411c94a13e00974b643b70fb0225a5fb030f7f703fa474",
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
diff --git a/third_party/gpus/crosstool/CROSSTOOL.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl
index 3972c96a2f..3189cf8e31 100644
--- a/third_party/gpus/crosstool/CROSSTOOL.tpl
+++ b/third_party/gpus/crosstool/CROSSTOOL.tpl
@@ -208,7 +208,7 @@ toolchain {
action: "c++-link-dynamic-library"
action: "c++-link-nodeps-dynamic-library"
flag_group {
- flag: "-B/usr/bin/"
+ %{linker_bin_path_flag}
}
}
}
@@ -446,7 +446,7 @@ toolchain {
action: "c++-link-dynamic-library"
action: "c++-link-nodeps-dynamic-library"
flag_group {
- flag: "-B/usr/bin/"
+ %{linker_bin_path_flag}
}
}
}
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index f6a39aeaf1..5648b1525a 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -1303,6 +1303,19 @@ def _create_local_cuda_repository(repository_ctx):
host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
cuda_defines = {}
+ # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
+ # https://github.com/bazelbuild/bazel/issues/760).
+ # However, this stops our custom clang toolchain from picking the provided
+ # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
+ # toolchain.
+ # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
+ # flag from the CROSSTOOL completely (see
+ # https://github.com/bazelbuild/bazel/issues/5634)
+ if should_download_clang:
+ cuda_defines["%{linker_bin_path_flag}"] = ""
+ else:
+ cuda_defines["%{linker_bin_path_flag}"] = 'flag: "-B/usr/bin"'
+
if is_cuda_clang:
cuda_defines["%{host_compiler_path}"] = str(cc)
cuda_defines["%{host_compiler_warnings}"] = """
diff --git a/third_party/toolchains/gpus/crosstool/BUILD b/third_party/toolchains/gpus/crosstool/BUILD
index 1f9065007c..bb0b6b3bbb 100644
--- a/third_party/toolchains/gpus/crosstool/BUILD
+++ b/third_party/toolchains/gpus/crosstool/BUILD
@@ -11,6 +11,7 @@ cc_toolchain_suite(
toolchains = {
"local|compiler": ":cc-compiler-local",
"darwin|compiler": ":cc-compiler-darwin",
+ "x64_windows|msvc-cl": ":cc-compiler-windows",
},
)
@@ -46,6 +47,20 @@ cc_toolchain(
supports_param_files = 0,
)
+cc_toolchain(
+ name = "cc-compiler-windows",
+ all_files = ":empty",
+ compiler_files = ":empty",
+ cpu = "x64_windows",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":empty",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 1,
+)
+
filegroup(
name = "empty",
srcs = [],
@@ -55,3 +70,8 @@ filegroup(
name = "crosstool_wrapper_driver_is_not_gcc",
srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
)
+
+filegroup(
+ name = "windows_msvc_wrapper_files",
+ srcs = glob(["windows/msvc_*"]),
+)
diff --git a/third_party/toolchains/gpus/crosstool/CROSSTOOL b/third_party/toolchains/gpus/crosstool/CROSSTOOL
index d6ee7e38c4..b8eeb31ecb 100644
--- a/third_party/toolchains/gpus/crosstool/CROSSTOOL
+++ b/third_party/toolchains/gpus/crosstool/CROSSTOOL
@@ -26,6 +26,10 @@ default_toolchain {
cpu: "ppc"
toolchain_identifier: "local_linux"
}
+default_toolchain {
+ cpu: "x64_windows"
+ toolchain_identifier: "local_windows"
+}
toolchain {
abi_version: "local"
@@ -144,9 +148,11 @@ toolchain {
flag_group {
# All warnings are enabled. Maybe enable -Werror as well?
flag: "-Wall"
+
# Some parts of the codebase set -Werror and hit this warning, so
# switch it off for now.
flag: "-Wno-invalid-partial-specialization"
+
}
}
}
@@ -307,3 +313,1120 @@ toolchain {
cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu"
cxx_builtin_include_directory: "/usr/include"
}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+
+ # Some parts of the codebase set -Werror and hit this warning, so
+ # switch it off for now.
+ flag: "-Wno-invalid-partial-specialization"
+
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag:"-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin/"
+ }
+ }
+ }
+
+ feature {
+ name: "undefined-dynamic"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-undefined"
+ flag: "dynamic_lookup"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ implies: "undefined-dynamic"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "/usr/local/bin/clang" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enabled dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+ cxx_builtin_include_directory: "/usr/include/c++/5.4.0"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/5.4.0"
+ cxx_builtin_include_directory: "/usr/include/c++/5.4.0/backward"
+ cxx_builtin_include_directory: "/usr/local/include"
+ cxx_builtin_include_directory: "/usr/local/lib/clang/7.0.0/include"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu"
+ cxx_builtin_include_directory: "/usr/include"
+}
+
+toolchain {
+ toolchain_identifier: "local_windows"
+ host_system_name: "local"
+ target_system_name: "local"
+
+ abi_version: "local"
+ abi_libc_version: "local"
+ target_cpu: "x64_windows"
+ compiler: "msvc-cl"
+ target_libc: "msvcrt"
+
+
+
+ tool_path {
+ name: "ar"
+ path: ""
+ }
+ tool_path {
+ name: "ml"
+ path: ""
+ }
+ tool_path {
+ name: "cpp"
+ path: ""
+ }
+ tool_path {
+ name: "gcc"
+ path: ""
+ }
+ tool_path {
+ name: "gcov"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "ld"
+ path: ""
+ }
+ tool_path {
+ name: "nm"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objcopy"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objdump"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "strip"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ supports_interface_shared_objects: true
+
+ # TODO(pcloudy): Review those flags below, they should be defined by cl.exe
+ compiler_flag: "/DCOMPILER_MSVC"
+
+ # Don't define min/max macros in windows.h.
+ compiler_flag: "/DNOMINMAX"
+
+ # Platform defines.
+ compiler_flag: "/D_WIN32_WINNT=0x0600"
+ # Turn off warning messages.
+ compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE"
+ compiler_flag: "/D_CRT_SECURE_NO_WARNINGS"
+ compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS"
+
+ # Useful options to have on for compilation.
+ # Increase the capacity of object files to 2^32 sections.
+ compiler_flag: "/bigobj"
+ # Allocate 500MB for precomputed headers.
+ compiler_flag: "/Zm500"
+ # Use unsigned char by default.
+ compiler_flag: "/J"
+ # Use function level linking.
+ compiler_flag: "/Gy"
+ # Use string pooling.
+ compiler_flag: "/GF"
+ # Catch C++ exceptions only and tell the compiler to assume that functions declared
+ # as extern "C" never throw a C++ exception.
+ compiler_flag: "/EHsc"
+
+ # Globally disabled warnings.
+ # Don't warn about elements of array being be default initialized.
+ compiler_flag: "/wd4351"
+ # Don't warn about no matching delete found.
+ compiler_flag: "/wd4291"
+ # Don't warn about diamond inheritance patterns.
+ compiler_flag: "/wd4250"
+ # Don't warn about insecure functions (e.g. non _s functions).
+ compiler_flag: "/wd4996"
+
+ linker_flag: "/MACHINE:X64"
+
+ feature {
+ name: "no_legacy_features"
+ }
+
+ # Suppress startup banner.
+ feature {
+ name: "nologo"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ flag_group {
+ flag: "/nologo"
+ }
+ }
+ }
+
+ feature {
+ name: 'has_configured_linker_path'
+ }
+
+ # This feature indicates strip is not supported, building stripped binary will just result a copy of orignial binary
+ feature {
+ name: 'no_stripping'
+ }
+
+ # This feature indicates this is a toolchain targeting Windows.
+ feature {
+ name: 'targets_windows'
+ implies: 'copy_dynamic_libraries_to_binary'
+ enabled: true
+ }
+
+ feature {
+ name: 'copy_dynamic_libraries_to_binary'
+ }
+
+ action_config {
+ config_name: 'assemble'
+ action_name: 'assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'preprocess-assemble'
+ action_name: 'preprocess-assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'c-compile'
+ action_name: 'c-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-compile'
+ action_name: 'c++-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-link-executable'
+ action_name: 'c++-link-executable'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ }
+
+ action_config {
+ config_name: 'c++-link-dynamic-library'
+ action_name: 'c++-link-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-nodeps-dynamic-library'
+ action_name: 'c++-link-nodeps-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-static-library'
+ action_name: 'c++-link-static-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'archiver_flags'
+ implies: 'input_param_flags'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ }
+
+ # TODO(b/65151735): Remove legacy_compile_flags feature when legacy fields are
+ # not used in this crosstool
+ feature {
+ name: 'legacy_compile_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'legacy_compile_flags'
+ flag: '%{legacy_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: "msvc_env"
+ env_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ env_entry {
+ key: "PATH"
+ value: ""
+ }
+ env_entry {
+ key: "INCLUDE"
+ value: ""
+ }
+ env_entry {
+ key: "LIB"
+ value: ""
+ }
+ env_entry {
+ key: "TMP"
+ value: ""
+ }
+ env_entry {
+ key: "TEMP"
+ value: ""
+ }
+ }
+ }
+
+ feature {
+ name: 'include_paths'
+ flag_set {
+ action: "assemble"
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ flag_group {
+ iterate_over: 'quote_include_paths'
+ flag: '/I%{quote_include_paths}'
+ }
+ flag_group {
+ iterate_over: 'include_paths'
+ flag: '/I%{include_paths}'
+ }
+ flag_group {
+ iterate_over: 'system_include_paths'
+ flag: '/I%{system_include_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: "preprocessor_defines"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-module-compile"
+ flag_group {
+ flag: "/D%{preprocessor_defines}"
+ iterate_over: "preprocessor_defines"
+ }
+ }
+ }
+
+ # Tell Bazel to parse the output of /showIncludes
+ feature {
+ name: 'parse_showincludes'
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-module-compile'
+ action: 'c++-header-parsing'
+ flag_group {
+ flag: "/showIncludes"
+ }
+ }
+ }
+
+
+ feature {
+ name: 'generate_pdb_file'
+ requires: {
+ feature: 'dbg'
+ }
+ requires: {
+ feature: 'fastbuild'
+ }
+ }
+
+ feature {
+ name: 'shared_flag'
+ flag_set {
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/DLL'
+ }
+ }
+ }
+
+ feature {
+ name: 'linkstamps'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ expand_if_all_available: 'linkstamp_paths'
+ flag_group {
+ iterate_over: 'linkstamp_paths'
+ flag: '%{linkstamp_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: 'output_execpath_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'archiver_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'input_param_flags'
+ flag_set {
+ expand_if_all_available: 'interface_library_output_path'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/IMPLIB:%{interface_library_output_path}"
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libopts'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'libopts'
+ flag: '%{libopts}'
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libraries_to_link'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ iterate_over: 'libraries_to_link'
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file_group'
+ }
+ iterate_over: 'libraries_to_link.object_files'
+ flag_group {
+ flag: '%{libraries_to_link.object_files}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'interface_library'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'static_library'
+ }
+ flag_group {
+ expand_if_false: 'libraries_to_link.is_whole_archive'
+ flag: '%{libraries_to_link.name}'
+ }
+ flag_group {
+ expand_if_true: 'libraries_to_link.is_whole_archive'
+ flag: '/WHOLEARCHIVE:%{libraries_to_link.name}'
+ }
+ }
+ }
+ }
+ }
+
+ # Since this feature is declared earlier in the CROSSTOOL than
+ # "user_link_flags", this feature will be applied prior to it anwyhere they
+ # are both implied. And since "user_link_flags" contains the linkopts from
+ # the build rule, this allows the user to override the /SUBSYSTEM in the BUILD
+ # file.
+ feature {
+ name: 'linker_subsystem_flag'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/SUBSYSTEM:CONSOLE'
+ }
+ }
+ }
+
+ # The "user_link_flags" contains user-defined linkopts (from build rules)
+ # so it should be defined after features that declare user-overridable flags.
+ # For example the "linker_subsystem_flag" defines a default "/SUBSYSTEM" flag
+ # but we want to let the user override it, therefore "link_flag_subsystem" is
+ # defined earlier in the CROSSTOOL file than "user_link_flags".
+ feature {
+ name: 'user_link_flags'
+ flag_set {
+ expand_if_all_available: 'user_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'user_link_flags'
+ flag: '%{user_link_flags}'
+ }
+ }
+ }
+ feature {
+ name: 'legacy_link_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'legacy_link_flags'
+ flag: '%{legacy_link_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'linker_param_file'
+ flag_set {
+ expand_if_all_available: 'linker_param_file'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '@%{linker_param_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'static_link_msvcrt'
+ }
+
+ feature {
+ name: 'static_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MT"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MD"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'static_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MTd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MDd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dbg'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FULL"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'fastbuild'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FASTLINK"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'opt'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/O2"
+ flag: "/DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: 'user_compile_flags'
+ flag_set {
+ expand_if_all_available: 'user_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'user_compile_flags'
+ flag: '%{user_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'sysroot'
+ flag_set {
+ expand_if_all_available: 'sysroot'
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'sysroot'
+ flag: '--sysroot=%{sysroot}'
+ }
+ }
+ }
+
+ feature {
+ name: 'unfiltered_compile_flags'
+ flag_set {
+ expand_if_all_available: 'unfiltered_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'unfiltered_compile_flags'
+ flag: '%{unfiltered_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_output_flags'
+ flag_set {
+ action: 'assemble'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ flag: '/Zi'
+ }
+ }
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_assembly_file'
+ flag: '/Fa%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_preprocess_file'
+ flag: '/P'
+ flag: '/Fi%{output_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_input_flags'
+ flag_set {
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'source_file'
+ flag: '/c'
+ flag: '%{source_file}'
+ }
+ }
+ }
+
+ feature {
+ name : 'def_file',
+ flag_set {
+ expand_if_all_available: 'def_file_path'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEF:%{def_file_path}"
+ # We can specify a different DLL name in DEF file, /ignore:4070 suppresses
+ # the warning message about DLL name doesn't match the default one.
+ # See https://msdn.microsoft.com/en-us/library/sfkk2fz7.aspx
+ flag: "/ignore:4070"
+ }
+ }
+ }
+
+ feature {
+ name: 'windows_export_all_symbols'
+ }
+
+ feature {
+ name: 'no_windows_export_all_symbols'
+ }
+
+ linking_mode_flags { mode: DYNAMIC }
+} \ No newline at end of file
diff --git a/third_party/toolchains/gpus/cuda/BUILD b/third_party/toolchains/gpus/cuda/BUILD
index 4cb8380938..f59e025019 100644
--- a/third_party/toolchains/gpus/cuda/BUILD
+++ b/third_party/toolchains/gpus/cuda/BUILD
@@ -133,6 +133,15 @@ cc_library(
)
cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
name = "cufft",
srcs = ["cuda/lib/libcufft.so.9.0"],
data = ["cuda/lib/libcufft.so.9.0"],
@@ -1191,33 +1200,10 @@ if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/includ
genrule(
name = "cuda-nvvm",
outs = [
- "cuda/nvvm/bin/cicc",
- "cuda/nvvm/include/nvvm.h",
- "cuda/nvvm/lib64/libnvvm.so",
- "cuda/nvvm/lib64/libnvvm.so.3",
- "cuda/nvvm/lib64/libnvvm.so.3.2.0",
"cuda/nvvm/libdevice/libdevice.10.bc",
- "cuda/nvvm/libnvvm-samples/CMakeLists.txt",
- "cuda/nvvm/libnvvm-samples/README.txt",
- "cuda/nvvm/libnvvm-samples/build.bat",
- "cuda/nvvm/libnvvm-samples/build.sh",
- "cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h",
- "cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h",
- "cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt",
- "cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt",
- "cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp",
- "cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu",
- "cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt",
- "cuda/nvvm/libnvvm-samples/ptxgen/README.txt",
- "cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c",
- "cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt",
- "cuda/nvvm/libnvvm-samples/simple/README.txt",
- "cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll",
- "cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll",
- "cuda/nvvm/libnvvm-samples/simple/simple.c",
],
cmd = """
-if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/nvvm/bin/cicc" "$(@D)/cuda/nvvm/bin/cicc" && cp "/usr/local/cuda-9.0/nvvm/include/nvvm.h" "$(@D)/cuda/nvvm/include/nvvm.h" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so" "$(@D)/cuda/nvvm/lib64/libnvvm.so" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so.3" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3" && cp "/usr/local/cuda-9.0/nvvm/lib64/libnvvm.so.3.2.0" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3.2.0" && cp "/usr/local/cuda-9.0/nvvm/libdevice/libdevice.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.10.bc" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/build.bat" "$(@D)/cuda/nvvm/libnvvm-samples/build.bat" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/build.sh" "$(@D)/cuda/nvvm/libnvvm-samples/build.sh" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/common/include/DDSWriter.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/ptxgen/ptxgen.c" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/README.txt" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple-gpu.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple-gpu64.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll" && cp "/usr/local/cuda-9.0/nvvm/libnvvm-samples/simple/simple.c" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple.c"
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/nvvm/libdevice/libdevice.10.bc" "$(@D)//libdevice.10.bc"
""",
)
@@ -1272,7 +1258,7 @@ genrule(
"cuda/lib/libcupti.so.9.0",
],
cmd = """
-if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0" "$(@D)/cuda/lib/libcupti.so.9.0"
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
""",
)
diff --git a/third_party/toolchains/gpus/cuda/build_defs.bzl b/third_party/toolchains/gpus/cuda/build_defs.bzl
index badaf43019..9210bfe016 100644
--- a/third_party/toolchains/gpus/cuda/build_defs.bzl
+++ b/third_party/toolchains/gpus/cuda/build_defs.bzl
@@ -2,6 +2,7 @@
# execution service.
# DO NOT EDIT: automatically generated file
+# Macros for building CUDA code.
def if_cuda(if_true, if_false = []):
"""Shorthand for select()'ing on whether we're building with CUDA.
@@ -12,15 +13,13 @@ def if_cuda(if_true, if_false = []):
return select({
"@local_config_cuda//cuda:using_nvcc": if_true,
"@local_config_cuda//cuda:using_clang": if_true,
- "//conditions:default": if_false
+ "//conditions:default": if_false,
})
-
def cuda_default_copts():
"""Default options for all CUDA compilations."""
return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + ["--cuda-gpu-arch=sm_30"])
-
def cuda_is_configured():
"""Returns true if CUDA was enabled during the configure process."""
return True
@@ -32,6 +31,5 @@ def if_cuda_is_configured(x):
--config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
"""
if cuda_is_configured():
- return x
+ return x
return []
-
diff --git a/third_party/toolchains/gpus/cuda/cuda/cuda_config.h b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h
index f6662274cc..7cdaf144ad 100644
--- a/third_party/toolchains/gpus/cuda/cuda/cuda_config.h
+++ b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h
@@ -19,9 +19,9 @@ limitations under the License.
#define TF_CUDA_CAPABILITIES CudaVersion("3.0")
-#define TF_CUDA_VERSION "8.0"
-#define TF_CUDNN_VERSION "5"
+#define TF_CUDA_VERSION "9.0"
+#define TF_CUDNN_VERSION "7"
-#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-8.0"
+#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-9.0"
#endif // CUDA_CUDA_CONFIG_H_
diff --git a/third_party/toolchains/gpus/py/BUILD b/third_party/toolchains/gpus/py/BUILD
index 2d5ace93ff..1235988abb 100644
--- a/third_party/toolchains/gpus/py/BUILD
+++ b/third_party/toolchains/gpus/py/BUILD
@@ -6,18 +6,24 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
+# See https://docs.python.org/3/extending/windows.html
+cc_import(
+ name = "python_lib",
+ interface_library = select({
+ ":windows": ":python_import_lib",
+ # A placeholder for Unix platforms which makes --no_build happy.
+ "//conditions:default": "not-existing.lib",
+ }),
+ system_provided = 1,
+)
+
cc_library(
name = "python_headers",
hdrs = [":python_include"],
- data = select({
- ":windows": [":python_import_lib"],
- "//conditions:default": [],
- }),
includes = ["python_include"],
- linkopts = select({
- # TODO(pcloudy): Ideally, this should just go into deps after resolving
- # https://github.com/bazelbuild/bazel/issues/3237,
- ":windows": ["$(locations :python_import_lib)"],
+ deps = select({
+ ":windows": [":python_lib"],
"//conditions:default": [],
}),
)
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 660e3d3280..601e07ffdd 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -33,6 +33,11 @@ build:mkl_open_source_only --define=using_mkl_dnn_only=true
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
+# Instruct clang to use LLD for linking.
+# This only works with GPU builds currently, since Bazel sets -B/usr/bin in
+# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over
+# the downloaded one.
+build:download_clang_use_lld --linkopt='-fuse-ld=lld'
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true